Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Violin plot #316

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ uuid = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
keywords = ["markov chain monte carlo", "probablistic programming"]
license = "MIT"
desc = "Chain types and utility functions for MCMC simulations."
version = "4.13.1"
version = "4.15.0"

[deps]
AbstractFFTs = "621f4979-c628-5d54-868e-fcf4e3e8185c"
Expand Down
51 changes: 49 additions & 2 deletions src/plot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,14 @@
@shorthands pooleddensity
@shorthands traceplot
@shorthands corner
@shorthands violinplot

struct _TracePlot; c; val; end
struct _MeanPlot; c; val; end
struct _DensityPlot; c; val; end
struct _HistogramPlot; c; val; end
struct _AutocorPlot; lags; val; end
struct _ViolinPlot; par; val; end

# define alias functions for old syntax
const translationdict = Dict(
Expand All @@ -18,7 +20,8 @@ const translationdict = Dict(
:density => _DensityPlot,
:histogram => _HistogramPlot,
:autocorplot => _AutocorPlot,
:pooleddensity => _DensityPlot
:pooleddensity => _DensityPlot,
:violinplot => _ViolinPlot
)

const supportedplots = push!(collect(keys(translationdict)), :mixeddensity, :corner)
Expand All @@ -30,7 +33,9 @@ const supportedplots = push!(collect(keys(translationdict)), :mixeddensity, :cor
colordim = :chain,
barbounds = (-Inf, Inf),
maxlag = nothing,
append_chains = false
append_chains = false,
par_sections = chains.name_map[:parameters],
combined = true
)
st = get(plotattributes, :seriestype, :traceplot)
c = append_chains || st == :pooleddensity ? pool_chain(chains) : chains
Expand Down Expand Up @@ -64,6 +69,33 @@ const supportedplots = push!(collect(keys(translationdict)), :mixeddensity, :cor
ac_mat = convert(Array, ac)
val = colordim == :parameter ? ac_mat[:, :, i]' : ac_mat[i, :, :]
_AutocorPlot(lags, val)

elseif st == :violinplot
n_iter, n_par, n_chains = size(chains)
if combined
colordim := :chain
par = string.(reshape(repeat(par_sections, inner = n_iter), n_iter, n_par))[:,i]
val = Array(chains)[:,i]
_ViolinPlot(par, val)
elseif combined == false
if colordim == :chain
par_names = ["$(par_sections[i]).Chain $j" for i in 1:n_par, j in 1:n_chains]
pars = string.(reshape(repeat(vec(par_names), inner = n_iter), (n_iter, n_par, n_chains)))
val = chains.value[:,i,:]
par = pars[:,i,:]
elseif colordim == :parameter
par_vec = repeat(par_sections, inner = n_iter)
pars = string.(reshape(repeat(par_vec, n_chains, 1), (n_iter, n_par, n_chains)))
val = chains.value[:,:,i]
par = pars[:,:,i]
label --> string.(names(c))
else
throw(ArgumentError("`colordim` must be one of `:chain` or `:parameter`"))
end
_ViolinPlot(par, val)
else
throw(ArgumentError("In `ViolinPlots` `Chains` can be combined or separated "))
end
elseif st ∈ supportedplots
translationdict[st](c, val)
else
Expand Down Expand Up @@ -184,3 +216,18 @@ end
ar = collect(Array(corner.c.value[:, corner.parameters,i]) for i in chains(corner.c))
RecipesBase.recipetype(:cornerplot, vcat(ar...))
end

@recipe function f(p::_ViolinPlot)
@series begin
seriestype := :violin
p.par, p.val
end

@series begin
seriestype := :boxplot
bar_width --> 0.1
linewidth --> 2
fillalpha --> 0.8
p.par, p.val
end
end
17 changes: 10 additions & 7 deletions test/plot_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,37 +24,40 @@ Logging.disable_logging(Logging.Warn)
println("traceplot")
display(traceplot(chn, 1))
println()

println("meanplot")
display(meanplot(chn, 1))
println()

println("density")
display(density(chn, 1))
display(density(chn, 1, append_chains=true))
println()

println("autocorplot")
display(autocorplot(chn, 1))
println()

#ps_contour = plot(chn, :contour)

println("histogram")
display(histogram(chn, 1))
println()

println("\nmixeddensity")
display(mixeddensity(chn, 1))


println("violinplot")
display(violinplot(chn))
println()
# plotting combinations
display(plot(chn))
display(plot(chn, append_chains=true))
display(plot(chn, seriestype = (:mixeddensity, :autocorplot)))

# Test plotting using colordim keyword
display(plot(chn, colordim = :parameter))

# Test if plotting a sub-set work.s
display(plot(chn, 2))
display(plot(chn, 2, colordim = :parameter))
Expand Down