diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 6301c7e..32dad55 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -21,7 +21,7 @@ jobs: - '1.7' - '1.8' - '1.9' - - 'nightly' + - '1.10' os: - ubuntu-latest arch: diff --git a/Project.toml b/Project.toml index c368817..ff6d8eb 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "TaijaPlotting" uuid = "bd7198b4-c7d6-400c-9bab-9a24614b0240" authors = ["Patrick Altmeyer"] -version = "1.0.5" +version = "1.0.6" [deps] CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597" @@ -13,6 +13,7 @@ Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c" LaplaceRedux = "c52c1a26-f7c5-402b-80be-ba1e638ad478" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d" +MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54" ManifoldLearning = "06eb3307-b2af-5a2a-abea-d33192699d32" MultivariateStats = "6f286f6a-111f-5878-ab1e-185364afe411" NaturalSort = "c020b1a1-e9b0-503a-9c33-f039bfc54a85" @@ -30,11 +31,12 @@ LaplaceRedux = "0.1" LinearAlgebra = "1.7, 1.8, 1.9" MLJBase = "0.21, 0.22, 1" ManifoldLearning = "0.9" +MLUtils = "0.4" MultivariateStats = "0.9, 0.10" NaturalSort = "1" NearestNeighborModels = "0.2" Plots = "1" -julia = "1.7, 1.8, 1.9" +julia = "1.7, 1.8, 1.9, 1.10" [extras] Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/src/CounterfactualExplations/CounterfactualExplanations.jl b/src/CounterfactualExplations/CounterfactualExplanations.jl index f43f2b9..30d2b54 100644 --- a/src/CounterfactualExplations/CounterfactualExplanations.jl +++ b/src/CounterfactualExplations/CounterfactualExplanations.jl @@ -4,4 +4,4 @@ using CounterfactualExplanations.Models include("data.jl") include("models.jl") -include("countefactuals.jl") \ No newline at end of file +include("counterfactuals.jl") \ No newline at end of file diff --git a/src/CounterfactualExplations/countefactuals.jl b/src/CounterfactualExplations/counterfactuals.jl similarity index 96% rename from src/CounterfactualExplations/countefactuals.jl rename to src/CounterfactualExplations/counterfactuals.jl index 91de5f4..b6eabc6 100644 --- a/src/CounterfactualExplations/countefactuals.jl +++ b/src/CounterfactualExplations/counterfactuals.jl @@ -1,3 +1,5 @@ +using MLUtils: stack + """ Plots.plot( ce::CounterfactualExplanation; @@ -103,21 +105,21 @@ end plot_state( ce::CounterfactualExplanation, t::Int, - final_sate::Bool; + final_state::Bool; kwargs... ) Helper function that plots a single step of the counterfactual path. """ -function plot_state(ce::CounterfactualExplanation, t::Int, final_sate::Bool; kwargs...) +function plot_state(ce::CounterfactualExplanation, t::Int, final_state::Bool; kwargs...) args = PlotIngredients(; kwargs...) - x1 = [args.path_embedded[1,t]] - x2 = [args.path_embedded[2,t]] + x1 = args.path_embedded[1,t,:] + x2 = args.path_embedded[2,t,:] y = args.path_labels[t] _c = CategoricalArrays.levelcode.(y) n_ = ce.num_counterfactuals label_ = reshape(["C$i" for i in 1:n_], 1, n_) - if !final_sate + if !final_state scatter!(args.p1, x1, x2; group=y, colour=_c, ms=5, label="") else scatter!(args.p1, x1, x2; group=y, colour=_c, ms=10, label="") diff --git a/src/CounterfactualExplations/data.jl b/src/CounterfactualExplations/data.jl index c29362a..ae96509 100644 --- a/src/CounterfactualExplations/data.jl +++ b/src/CounterfactualExplations/data.jl @@ -22,7 +22,7 @@ function embed(data::CounterfactualData, X::AbstractArray=nothing; dim_red::Symb end # Transforming: - X = typeof(X) <: Vector{<:Matrix} ? hcat(X...) : X + X = typeof(X) <: Vector{<:Matrix} ? stack(X, dims=2) : X if !isnothing(tfn) && !isnothing(X) X = MultivariateStats.predict(tfn, X) else