Skip to content

Commit

Permalink
Merge branch 'pm/violin_plot' of https://github.com/PaulinaMartin96/M…
Browse files Browse the repository at this point in the history
…CMCChains.jl into pm/violin_plot
  • Loading branch information
PaulinaMartin96 committed Jul 29, 2021
2 parents 6669d0a + cf89757 commit 3d14727
Show file tree
Hide file tree
Showing 12 changed files with 102 additions and 48 deletions.
7 changes: 4 additions & 3 deletions docs/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,16 @@ CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
Gadfly = "c91e804a-d5a3-530f-b6f0-dfbca275c004"
MLJModels = "d491faf4-2d78-11e9-2867-c94bc002c0b7"
MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
MLJXGBoostInterface = "54119dfa-1dab-4055-a167-80440f4f7a91"
StatsPlots = "f3b207a7-027a-5e70-b257-86293d7955fd"

[compat]
CategoricalArrays = "0.8, 0.9, 0.10"
DataFrames = "0.22, 1"
Documenter = "0.26"
Documenter = "0.26, 0.27"
Gadfly = "1.3"
MLJModels = "0.14"
MLJBase = "0.18"
MLJXGBoostInterface = "0.1"
StatsPlots = "0.14"
julia = "1.3"
35 changes: 27 additions & 8 deletions src/chains.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,22 @@ function Chains(
name_map = (parameters = parameter_names,);
start::Int = 1,
thin::Int = 1,
iterations::AbstractVector{Int} = range(start; step=thin, length=size(val, 1)),
evidence = missing,
info::NamedTuple = NamedTuple()
)
# Check that iteration numbers are reasonable
if length(iterations) != size(val, 1)
error("length of `iterations` (", length(iterations),
") is not equal to the number of iterations (", size(val, 1), ")")
end
if !isempty(iterations) && first(iterations) < 1
error("iteration numbers must be positive integers")
end
if !isstrictlyincreasing(iterations)
error("iteration numbers must be strictly increasing")
end

# Make sure that we have a `:parameters` index. Copying can avoid state mutation.
_name_map = initnamemap(name_map)

Expand All @@ -58,7 +71,7 @@ function Chains(

# Construct the AxisArray.
arr = AxisArray(val;
iter = range(start, step=thin, length=size(val, 1)),
iter = iterations,
var = parameter_names,
chain = 1:size(val, 3))

Expand Down Expand Up @@ -444,17 +457,21 @@ Return the range of iteration indices of the `chains`.
Base.range(chains::Chains) = chains.value[Axis{:iter}].val

"""
setrange(chains::Chains, range)
setrange(chains::Chains, range::AbstractVector{Int})
Generate a new chain from `chains` with iterations indexed by `range`.
The new chain and `chains` share the same data in memory.
"""
function setrange(chains::Chains, range::AbstractRange{<:Integer})
function setrange(chains::Chains, range::AbstractVector{Int})
if length(chains) != length(range)
error("length of `range` (", length(range),
") is not equal to the number of iterations (", length(chains), ")")
end
if !isempty(range) && first(range) < 1
error("iteration numbers must be positive integers")
end
isstrictlyincreasing(range) || error("iteration numbers must be strictly increasing")

value = AxisArray(chains.value.data;
iter = range, var = names(chains), chain = MCMCChains.chains(chains))
Expand Down Expand Up @@ -574,8 +591,7 @@ function header(c::Chains; section=missing)
# Return header.
return string(
ismissing(c.logevidence) ? "" : "Log evidence = $(c.logevidence)\n",
"Iterations = $(first(c)):$(last(c))\n",
"Thinning interval = $(step(c))\n",
"Iterations = $(range(c))\n",
"Number of chains = $(size(c, 3))\n",
"Samples per chain = $(length(range(c)))\n",
ismissing(wall) ? "" : "Wall duration = $(round(wall, digits=2)) seconds\n",
Expand Down Expand Up @@ -725,8 +741,11 @@ _cat(dim::Int, cs::Chains...) = _cat(Val(dim), cs...)

function _cat(::Val{1}, c1::Chains, args::Chains...)
# check inputs
thin = step(c1)
all(c -> step(c) == thin, args) || throw(ArgumentError("chain thinning differs"))
lastiter = last(c1)
for c in args
first(c) > lastiter || throw(ArgumentError("iterations have to be sorted"))
lastiter = last(c)
end
nms = names(c1)
all(c -> names(c) == nms, args) || throw(ArgumentError("chain names differ"))
chns = chains(c1)
Expand All @@ -735,7 +754,7 @@ function _cat(::Val{1}, c1::Chains, args::Chains...)
# concatenate all chains
data = mapreduce(c -> c.value.data, vcat, args; init = c1.value.data)
value = AxisArray(data;
iter = range(first(c1); length = size(data, 1), step = thin),
iter = mapreduce(range, vcat, args; init=range(c1)),
var = nms,
chain = chns)

Expand Down
4 changes: 2 additions & 2 deletions src/constructors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,8 @@ function Base.Array(
end

"""
    to_matrix(chain::Chains)

Return the samples stored in `chain` as a `Matrix` with one column per variable.

The underlying data array is indexed as `(iter, var, chain)`. The rows of the result
stack the iterations of the first chain, followed by the iterations of the second
chain, and so on.
"""
function to_matrix(chain::Chains)
    # Bring the chain dimension next to the iteration dimension:
    # (iter, var, chain) -> (iter, chain, var).
    x = permutedims(chain.value.data, (1, 3, 2))

    # Spell out both target sizes instead of using a `:` placeholder, so that the
    # reshape is also well-defined when the chain has no variables (size-zero dimension).
    return Matrix(reshape(x, size(x, 1) * size(x, 2), size(x, 3)))
end

function to_vector(chain::Chains)
Expand All @@ -79,4 +80,3 @@ function to_vector_of_matrices(chain::Chains)
data = chain.value.data
return [Matrix(data[:, :, i]) for i in axes(data, 3)]
end

2 changes: 1 addition & 1 deletion src/fileio.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,5 +22,5 @@ function readcoda(output::AbstractString, index::AbstractString)
value[:, i] = out[inds, 2]
end

Chains(value, start=first(window), thin=step(window), names=names)
Chains(value; iterations=window, names=names)
end
5 changes: 2 additions & 3 deletions src/rstar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,11 @@ verbosity level.
# Example
```jldoctest rstar; output = false, filter = r".*"s
using MLJModels
using MLJBase, MLJXGBoostInterface
XGBoost = @load XGBoostClassifier verbosity=0
chn = Chains(fill(4, 100, 2, 3))
Rs = rstar(XGBoost(), chn; iterations=20)
Rs = rstar(XGBoostClassifier(), chn; iterations=20)
R = round(mean(Rs); digits=0)
# output
Expand Down
22 changes: 18 additions & 4 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -96,7 +96,7 @@ function merge_union(a::NamedTuple{an}, b::NamedTuple{bn}) where {an,bn}
:(getfield(b, $(QuoteNode(n))))
end
end

return :(NamedTuple{$names,$types}(($(values...),)))
else
names = Base.merge_names(an, bn)
Expand All @@ -113,7 +113,7 @@ function merge_union(a::NamedTuple{an}, b::NamedTuple{bn}) where {an,bn}
getfield(b, n)
end
end

return NamedTuple{names,types}(values)
end
end
Expand Down Expand Up @@ -179,8 +179,8 @@ function concretize(x::AbstractArray)
return x
else
xnew = map(concretize, x)
T = mapreduce(typeof, promote_type, xnew)
if T <: eltype(xnew)
T = mapreduce(typeof, promote_type, xnew; init=Union{})
if T <: eltype(xnew) && T !== Union{}
return convert(AbstractArray{T}, xnew)
else
return xnew
Expand All @@ -196,3 +196,17 @@ function concretize(x::Chains)
return Chains(concretize(value), x.logevidence, x.name_map, x.info)
end
end

"""
    isstrictlyincreasing(x::AbstractVector{Int})

Return `true` if every element of `x` is strictly greater than its predecessor,
and `false` otherwise.

Empty and single-element vectors are considered strictly increasing.
"""
function isstrictlyincreasing(x::AbstractVector{Int})
    return isempty(x) || _isstrictlyincreasing_nonempty(x)
end

# Ranges can be decided from their step without iterating. A one-element range is
# strictly increasing whatever its step is (e.g. `5:-1:5`), which keeps this method
# consistent with the generic fallback below.
function _isstrictlyincreasing_nonempty(x::AbstractRange{Int})
    return length(x) == 1 || step(x) > 0
end

# Generic fallback for non-empty vectors: compare each element with its predecessor.
function _isstrictlyincreasing_nonempty(x::AbstractVector{Int})
    prev = first(x)
    for cur in Iterators.drop(x, 1)
        cur > prev || return false
        prev = cur
    end
    return true
end
2 changes: 1 addition & 1 deletion test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"
AbstractMCMC = "2.2.1, 3.0"
DataFrames = "0.22.4, 1.0"
Distributions = "0.24.12, 0.25"
Documenter = "0.26"
Documenter = "0.26, 0.27"
FFTW = "1.1"
IteratorInterfaceExtensions = "1"
KernelDensity = "0.6.2"
Expand Down
5 changes: 5 additions & 0 deletions test/arrayconstructor_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,11 @@ using MCMCChains, Test
Array(chns[:a])
Array(chns, [:parameters])
Array(chns, [:parameters, :internals])

# empty chain: #317
empty_chain = chns[Symbol[]]
@test isempty(MCMCChains.to_matrix(empty_chain))
@test isempty(Array(empty_chain))
end
@testset "Accuracy" begin
nchains = 5
Expand Down
30 changes: 16 additions & 14 deletions test/concatenation_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,58 +46,60 @@ end
chn = Chains(rand(10, 5, 2), ["a", "b", "c", "d", "e"], Dict(:internal => ["d", "e"]))
chn1 = Chains(rand(5, 5, 2), ["a", "b", "c", "d", "e"], Dict(:internal => ["a", "b"]))

# incorrect thinning
@test_throws ArgumentError vcat(chn, Chains(rand(2, 5, 2); thin = 2))
# incorrect iterations
@test_throws ArgumentError vcat(chn, Chains(rand(2, 5, 2)))

# incorrect names
@test_throws ArgumentError vcat(chn, Chains(rand(10, 5, 2), ["a", "b", "c", "d", "f"]))
@test_throws ArgumentError vcat(chn, Chains(rand(10, 5, 2), ["a", "b", "c", "d", "f"]; start=11))

# incorrect number of chains
@test_throws ArgumentError vcat(chn, Chains(rand(10, 5, 3), ["a", "b", "c", "d", "e"]))
@test_throws ArgumentError vcat(chn, Chains(rand(10, 5, 3), ["a", "b", "c", "d", "e"]; start=11))

# concatenate the same chain
chn2 = vcat(chn, chn)
chn_shifted = setrange(chn, 11:20)
chn2 = vcat(chn, chn_shifted)
@test chn2.value.data == vcat(chn.value.data, chn.value.data)
@test size(chn2) == (20, 5, 2)
@test names(chn2) == names(chn)
@test range(chn2) == 1:20
@test chn2.name_map == (parameters = [:a, :b, :c], internal = [:d, :e])
chn2a = cat(chn, chn)

chn2a = cat(chn, chn_shifted)
@test chn2a.value == chn2.value
@test chn2a.name_map == chn2.name_map
@test chn2a.info == chn2.info

chn2b = cat(chn, chn; dims = Val(1))
chn2b = cat(chn, chn_shifted; dims = Val(1))
@test chn2b.value == chn2.value
@test chn2b.name_map == chn2.name_map
@test chn2b.info == chn2.info

chn2c = cat(chn, chn; dims = 1)
chn2c = cat(chn, chn_shifted; dims = 1)
@test chn2c.value == chn2.value
@test chn2c.name_map == chn2.name_map
@test chn2c.info == chn2.info

# concatenate a different chain
chn3 = vcat(chn, chn1)
chn1_shifted = setrange(chn1, 11:15)
chn3 = vcat(chn, chn1_shifted)
@test chn3.value.data == vcat(chn.value.data, chn1.value.data)
@test size(chn3) == (15, 5, 2)
@test names(chn3) == names(chn)
@test range(chn3) == 1:15
# just take the name map of first argument
@test chn3.name_map == (parameters = [:a, :b, :c], internal = [:d, :e])
chn3a = cat(chn, chn1)

chn3a = cat(chn, chn1_shifted)
@test chn3a.value == chn3.value
@test chn3a.name_map == chn3.name_map
@test chn3a.info == chn3.info

chn3b = cat(chn, chn1; dims = Val(1))
chn3b = cat(chn, chn1_shifted; dims = Val(1))
@test chn3b.value == chn3.value
@test chn3b.name_map == chn3.name_map
@test chn3b.info == chn3.info

chn3c = cat(chn, chn1; dims = 1)
chn3c = cat(chn, chn1_shifted; dims = 1)
@test chn3c.value == chn3.value
@test chn3c.name_map == chn3.name_map
@test chn3c.info == chn3.info
Expand Down
30 changes: 22 additions & 8 deletions test/diagnostic_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,12 @@ val = hcat(val, rand(1:2, niter, 1, nchains))

# construct a Chains object
chn = Chains(val, start = 1, thin = 2)
@test_throws ErrorException Chains(val; start=0, thin=2)
@test_throws ErrorException Chains(val; start=niter, thin=-1)
@test_throws ErrorException Chains(val; iterations=1:(niter - 1))
@test_throws ErrorException Chains(val; iterations=range(0; step=2, length=niter))
@test_throws ErrorException Chains(val; iterations=niter:-1:1)
@test_throws ErrorException Chains(val; iterations=ones(Int, niter))

# Chains object for discretediag
val_disc = rand(Int16, 200, nparams, nchains)
Expand All @@ -29,18 +35,26 @@ chn_disc = Chains(val_disc, start = 1, thin = 2)
@test keys(chn) == names(chn) == [:param_1, :param_2, :param_3, :param_4]

@test range(chn) == range(1; step = 2, length = niter)
@test range(chn) == range(Chains(val; iterations=range(chn)))
@test range(chn) == range(Chains(val; iterations=collect(range(chn))))

@test_throws ErrorException setrange(chn, 1:10)
@test_throws ErrorException setrange(chn, 0:(niter - 1))
@test_throws ErrorException setrange(chn, niter:-1:1)
@test_throws ErrorException setrange(chn, ones(Int, niter))
@test_throws MethodError setrange(chn, float.(range(chn)))

chn2 = setrange(chn, range(1; step = 10, length = niter))
@test range(chn2) == range(1; step = 10, length = niter)
@test names(chn2) === names(chn)
@test chains(chn2) === chains(chn)
@test chn2.value.data === chn.value.data
@test chn2.logevidence === chn.logevidence
@test chn2.name_map === chn.name_map
@test chn2.info == chn.info
chn2a = setrange(chn, range(1; step = 10, length = niter))
chn2b = setrange(chn, collect(range(1; step = 10, length = niter)))
for chn2 in (chn2a, chn2b)
@test range(chn2) == range(1; step = 10, length = niter)
@test names(chn2) === names(chn)
@test chains(chn2) === chains(chn)
@test chn2.value.data === chn.value.data
@test chn2.logevidence === chn.logevidence
@test chn2.name_map === chn.name_map
@test chn2.info == chn.info
end

chn3 = resetrange(chn)
@test range(chn3) == 1:niter
Expand Down
6 changes: 3 additions & 3 deletions test/rstar_tests.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using MCMCChains
using MLJModels
using MLJBase
using MLJXGBoostInterface
using Test

N = 1000
Expand All @@ -8,8 +9,7 @@ colnames = ["a", "b", "c", "d", "e", "f", "g", "h"]
internal_colnames = ["c", "d", "e", "f", "g", "h"]
chn = Chains(val, colnames, Dict(:internals => internal_colnames))

XGBoost = @load XGBoostClassifier
classif = XGBoost()
classif = XGBoostClassifier()

@testset "R star test" begin
# Compute R* statistic for a mixed chain.
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ Random.seed!(0)
if VERSION >= v"1.3" && Sys.WORD_SIZE == 64
# run tests related to rstar statistic
println("Rstar")
Pkg.add("MLJModels")
Pkg.add("MLJBase")
Pkg.add("MLJXGBoostInterface")
@time include("rstar_tests.jl")

Expand Down

0 comments on commit 3d14727

Please sign in to comment.