diff --git a/Project.toml b/Project.toml index 6ef941bb..56a4b19a 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" keywords = ["markov chain monte carlo", "probablistic programming"] license = "MIT" desc = "Chain types and utility functions for MCMC simulations." -version = "4.13.1" +version = "4.15.0" [deps] AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c" diff --git a/src/plot.jl b/src/plot.jl index 41edb8f2..fc1e3232 100644 --- a/src/plot.jl +++ b/src/plot.jl @@ -4,12 +4,14 @@ @shorthands pooleddensity @shorthands traceplot @shorthands corner +@shorthands violinplot struct _TracePlot; c; val; end struct _MeanPlot; c; val; end struct _DensityPlot; c; val; end struct _HistogramPlot; c; val; end struct _AutocorPlot; lags; val; end +struct _ViolinPlot; par; val; end # define alias functions for old syntax const translationdict = Dict( @@ -18,7 +20,8 @@ const translationdict = Dict( :density => _DensityPlot, :histogram => _HistogramPlot, :autocorplot => _AutocorPlot, - :pooleddensity => _DensityPlot + :pooleddensity => _DensityPlot, + :violinplot => _ViolinPlot ) const supportedplots = push!(collect(keys(translationdict)), :mixeddensity, :corner) @@ -30,7 +33,9 @@ const supportedplots = push!(collect(keys(translationdict)), :mixeddensity, :cor colordim = :chain, barbounds = (-Inf, Inf), maxlag = nothing, - append_chains = false + append_chains = false, + par_sections = chains.name_map[:parameters], + combined = true ) st = get(plotattributes, :seriestype, :traceplot) c = append_chains || st == :pooleddensity ? pool_chain(chains) : chains @@ -64,6 +69,33 @@ const supportedplots = push!(collect(keys(translationdict)), :mixeddensity, :cor ac_mat = convert(Array, ac) val = colordim == :parameter ? ac_mat[:, :, i]' : ac_mat[i, :, :] _AutocorPlot(lags, val) + + elseif st == :violinplot + n_iter, n_par, n_chains = size(chains) + if combined + colordim := :chain + par = string.(reshape(repeat(par_sections, inner = n_iter), n_iter, n_par))[:,i] + val = Array(chains)[:,i] + _ViolinPlot(par, val) + elseif combined == false + if colordim == :chain + par_names = ["$(par_sections[i]).Chain $j" for i in 1:n_par, j in 1:n_chains] + pars = string.(reshape(repeat(vec(par_names), inner = n_iter), (n_iter, n_par, n_chains))) + val = chains.value[:,i,:] + par = pars[:,i,:] + elseif colordim == :parameter + par_vec = repeat(par_sections, inner = n_iter) + pars = string.(reshape(repeat(par_vec, n_chains, 1), (n_iter, n_par, n_chains))) + val = chains.value[:,:,i] + par = pars[:,:,i] + label --> string.(names(c)) + else + throw(ArgumentError("`colordim` must be one of `:chain` or `:parameter`")) + end + _ViolinPlot(par, val) + else + throw(ArgumentError("In `ViolinPlots` `Chains` can be combined or separated ")) + end elseif st ∈ supportedplots translationdict[st](c, val) else @@ -184,3 +216,18 @@ end ar = collect(Array(corner.c.value[:, corner.parameters,i]) for i in chains(corner.c)) RecipesBase.recipetype(:cornerplot, vcat(ar...)) end + +@recipe function f(p::_ViolinPlot) + @series begin + seriestype := :violin + p.par, p.val + end + + @series begin + seriestype := :boxplot + bar_width --> 0.1 + linewidth --> 2 + fillalpha --> 0.8 + p.par, p.val + end +end diff --git a/test/plot_test.jl b/test/plot_test.jl index 9654ed8c..bb142e5d 100644 --- a/test/plot_test.jl +++ b/test/plot_test.jl @@ -24,29 +24,32 @@ Logging.disable_logging(Logging.Warn) println("traceplot") display(traceplot(chn, 1)) println() - + println("meanplot") display(meanplot(chn, 1)) println() - + println("density") display(density(chn, 1)) display(density(chn, 1, append_chains=true)) println() - + println("autocorplot") display(autocorplot(chn, 1)) println() - + #ps_contour = plot(chn, :contour) println("histogram") display(histogram(chn, 1)) println() - + println("\nmixeddensity") display(mixeddensity(chn, 1)) - + + println("violinplot") + display(violinplot(chn)) + println() # plotting combinations display(plot(chn)) display(plot(chn, append_chains=true)) @@ -54,7 +57,7 @@ Logging.disable_logging(Logging.Warn) # Test plotting using colordim keyword display(plot(chn, colordim = :parameter)) - + # Test if plotting a sub-set work.s display(plot(chn, 2)) display(plot(chn, 2, colordim = :parameter))