diff --git a/src/plot.jl b/src/plot.jl index d6cfbb6f..c95e6798 100644 --- a/src/plot.jl +++ b/src/plot.jl @@ -11,7 +11,7 @@ struct _MeanPlot; c; val; end struct _DensityPlot; c; val; end struct _HistogramPlot; c; val; end struct _AutocorPlot; lags; val; end -struct _ViolinPlot; parameters; val; end +struct _ViolinPlot; parameters; val; total_chains; end # define alias functions for old syntax const translationdict = Dict( @@ -195,27 +195,28 @@ end ) st = get(plotattributes, :seriestype, :traceplot) + total_chains = 0 if st == :violinplot if combined parameters = string.(sections) val = Array(chains)[:, ] - _ViolinPlot(parameters, val) - + total_chains = Integer(size(chains.value.data)[3]) + _ViolinPlot(parameters, val, total_chains) elseif combined == false - data = Array(chains, append_chains = false) - parameters = vec(["param $(sections[i]).Chain $j" - for i in 1:length(sections), - j in 1:length(data)]) - val_vec = vec([data[j][:,i] for i in 1:length(sections), j in 1:length(data)]) + chain_arr = Array(chains, append_chains = false) + parameters = ["param $(sections[i]).Chain $j" + for i in 1:length(sections) + for j in 1:length(chain_arr)] + val_vec = [chain_arr[j][:,i] + for i in 1:length(sections) + for j in 1:length(chain_arr)] n_iter = length(val_vec[1]) - n_chains = length(val_vec) - val = zeros(Float64, n_iter, n_chains) - for i in 1:n_iter - for j in 1:n_chains - val[i,j] = val_vec[j][i] - end + total_chains = length(val_vec) + val = zeros(Float64, n_iter, total_chains) + for i in 1:total_chains + val[:,i] = val_vec[:][i] end - _ViolinPlot(parameters, val[:,]) + _ViolinPlot(parameters, val[:,], total_chains) else error("Symbol names are interpreted as parameter names, only compatible with ", "`colordim = :chain`") @@ -224,8 +225,18 @@ end end @recipe function f(p::_ViolinPlot) - seriestype := :violin - xaxis --> "Parameter" - p.parameters, p.val - #[collect(skipmissing(p.val[:,k])) for k in 1:size(p.val)] + @series begin + seriestype := :violin + xaxis --> "Parameter" + size --> (150*p.total_chains, 500) + p.parameters, p.val + end + + @series begin + seriestype := :boxplot + bar_width --> 0.1 + linewidth --> 2 + fillalpha --> 0.8 + p.parameters, p.val + end end