Skip to content

Commit

Permalink
Try #147:
Browse files Browse the repository at this point in the history
  • Loading branch information
bors[bot] authored Aug 11, 2020
2 parents bde1f74 + 7a64232 commit 63543dd
Show file tree
Hide file tree
Showing 6 changed files with 158 additions and 32 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
NaturalSort = "c020b1a1-e9b0-503a-9c33-f039bfc54a85"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

Expand All @@ -16,6 +17,7 @@ AbstractMCMC = "1"
Bijectors = "0.5.2, 0.6, 0.7, 0.8"
Distributions = "0.22, 0.23"
MacroTools = "0.5.1"
NaturalSort = "1"
ZygoteRules = "0.2"
julia = "1"

Expand Down
1 change: 1 addition & 0 deletions src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ using Distributions
using Bijectors

import AbstractMCMC
import NaturalSort
import MacroTools
import ZygoteRules

Expand Down
20 changes: 4 additions & 16 deletions src/prob_macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -190,12 +190,13 @@ function Distributions.loglikelihood(
# Element-wise likelihood for each value in chain
chain = right.chain
ctx = LikelihoodContext()
return map(1:length(chain)) do i
c = chain[i]
_setval!(vi, c)
iters = Iterators.product(1:size(chain, 1), 1:size(chain, 3))
logps = map(iters) do (sample_idx, chain_idx)
setval!(vi, chain, sample_idx, chain_idx)
model(vi, SampleFromPrior(), ctx)
return getlogp(vi)
end
return reshape(logps, size(chain, 1), size(chain, 3))
else
# Likelihood without chain
# Rhs values are used in the context
Expand Down Expand Up @@ -231,16 +232,3 @@ end
return :(Model{$(Tuple(missings))}(model.f, $(to_namedtuple_expr(argnames, argvals)),
model.defaults))
end

_setval!(vi::TypedVarInfo, c::AbstractChains) = _setval!(vi.metadata, vi, c)
@generated function _setval!(md::NamedTuple{names}, vi, c) where {names}
return Expr(:block, map(names) do n
quote
for vn in md.$n.vns
val = vec(c[Symbol(vn)])
setval!(vi, val, vn)
settrans!(vi, false, vn)
end
end
end...)
end
48 changes: 47 additions & 1 deletion src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ Set the value(s) of `vn` in the metadata of `vi` to `val`.
The values may or may not be transformed to Euclidean space.
"""
setval!(vi::VarInfo, val, vn::VarName) = getmetadata(vi, vn).vals[getrange(vi, vn)] = val
setval!(vi::VarInfo, val, vn::VarName) = getmetadata(vi, vn).vals[getrange(vi, vn)] = [val;]

"""
getval(vi::VarInfo, vns::Vector{<:VarName})
Expand Down Expand Up @@ -1144,3 +1144,49 @@ function updategid!(vi::AbstractVarInfo, vn::VarName, spl::Sampler)
setgid!(vi, spl.selector, vn)
end
end

setval!(vi::AbstractVarInfo, x) = _setval!(vi, values(x), keys(x))
function setval!(vi::AbstractVarInfo, chains::AbstractChains, sample_idx::Int, chain_idx::Int)
return _setval!(vi, chains.value[sample_idx, :, chain_idx], keys(chains))
end

function _setval!(vi::AbstractVarInfo, values, keys)
for vn in Base.keys(vi)
_setval_kernel!(vi, vn, values, keys)
end
return vi
end
_setval!(vi::TypedVarInfo, values, keys) = _typed_setval!(vi, vi.metadata, values, keys)
@generated function _typed_setval!(
vi::TypedVarInfo,
metadata::NamedTuple{names},
values,
keys
) where {names}
updates = map(names) do n
quote
for vn in metadata.$n.vns
_setval_kernel!(vi, vn, values, keys)
end
end
end

return quote
$(updates...)
return vi
end
end

function _setval_kernel!(vi::AbstractVarInfo, vn::VarName, values, keys)
sym = Symbol(vn)
regex = Regex("^$sym\$|^$sym\\[")
indices = findall(x -> match(regex, string(x)) !== nothing, keys)
if !isempty(indices)
sorted_indices = sort!(indices; by=i -> string(keys[i]), lt=NaturalSort.natural)
val = mapreduce(vcat, sorted_indices) do i
values[i]
end
setval!(vi, val, vn)
settrans!(vi, false, vn)
end
end
74 changes: 59 additions & 15 deletions test/prob_macro.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ Random.seed!(129)

@testset "prob_macro" begin
@testset "scalar" begin
@model demo(x) = begin
@model function demo(x)
m ~ Normal()
x ~ Normal(m, 1)
end
Expand All @@ -29,37 +29,44 @@ Random.seed!(129)
@test logprob"x = xval | m = mval, model = model" == loglike
@test logprob"x = xval, m = mval | model = model" == logjoint

varinfo = VarInfo(demo(missing))
@test logprob"x = xval, m = mval | model = model, varinfo = varinfo" == logjoint

varinfo = VarInfo(demo(xval))
@test logprob"m = mval | model = model, varinfo = varinfo" == logprior
@test logprob"m = mval | x = xval, model = model, varinfo = varinfo" == logprior
@test logprob"x = xval | m = mval, model = model, varinfo = varinfo" == loglike
varinfo = VarInfo(demo(missing))
@test logprob"x = xval, m = mval | model = model, varinfo = varinfo" == logjoint

chain = sample(demo(xval), IS(), iters; save_state = true)
chain2 = Chains(chain.value, chain.logevidence, chain.name_map, NamedTuple())
lps = logpdf.(Normal.(vec(chain["m"]), 1), xval)
lps = logpdf.(Normal.(chain["m"], 1), xval)
@test logprob"x = xval | chain = chain" == lps
@test logprob"x = xval | chain = chain2, model = model" == lps
varinfo = VarInfo(demo(xval))
@test logprob"x = xval | chain = chain, varinfo = varinfo" == lps
@test logprob"x = xval | chain = chain2, model = model, varinfo = varinfo" == lps

# multiple chains
pchain = chainscat(chain, chain)
pchain2 = chainscat(chain2, chain2)
plps = repeat(lps, 1, 2)
@test logprob"x = xval | chain = pchain" == plps
@test logprob"x = xval | chain = pchain2, model = model" == plps
@test logprob"x = xval | chain = pchain, varinfo = varinfo" == plps
@test logprob"x = xval | chain = pchain2, model = model, varinfo = varinfo" == plps
end

@testset "vector" begin
n = 5
@model demo(x, n = n, ::Type{T} = Float64) where {T} = begin
m = Vector{T}(undef, n)
@. m ~ Normal()
@. x ~ Normal.(m, 1)
@model function demo(x, n = n)
m ~ MvNormal(n, 1.0)
x ~ MvNormal(m, 1.0)
end
mval = rand(n)
xval = rand(n)
iters = 1000

logprior = sum(logpdf.(Normal(), mval))
like(m, x) = sum(logpdf.(Normal.(m, 1), x))
loglike = like(mval, xval)
logprior = logpdf(MvNormal(n, 1.0), mval)
loglike = logpdf(MvNormal(mval, 1.0), xval)
logjoint = logprior + loglike

model = demo(xval)
Expand All @@ -76,12 +83,49 @@ Random.seed!(129)
chain2 = Chains(chain.value, chain.logevidence, chain.name_map, NamedTuple())

names = namesingroup(chain, "m")
lps = map(1:iters) do iter
like([chain[iter, name, 1] for name in names], xval)
end
lps = [
logpdf(MvNormal(chain.value[i, names, j], 1.0), xval)
for i in 1:size(chain, 1), j in 1:size(chain, 3)
]
@test logprob"x = xval | chain = chain" == lps
@test logprob"x = xval | chain = chain2, model = model" == lps
@test logprob"x = xval | chain = chain, varinfo = varinfo" == lps
@test logprob"x = xval | chain = chain2, model = model, varinfo = varinfo" == lps

# multiple chains
pchain = chainscat(chain, chain)
pchain2 = chainscat(chain2, chain2)
plps = repeat(lps, 1, 2)
@test logprob"x = xval | chain = pchain" == plps
@test logprob"x = xval | chain = pchain2, model = model" == plps
@test logprob"x = xval | chain = pchain, varinfo = varinfo" == plps
@test logprob"x = xval | chain = pchain2, model = model, varinfo = varinfo" == plps
end

@testset "issue#137" begin
@model function model1(y, group, n_groups)
σ ~ truncated(Cauchy(0, 1), 0, Inf)
α ~ filldist(Normal(0, 10), n_groups)
μ = α[group]
y ~ MvNormal(μ, σ)
end

y = randn(100)
group = rand(1:4, 100)
n_groups = 4

chain1 = sample(model1(y, group, n_groups), NUTS(0.65), 2_000; save_state=true)
logprob"y = y[[1]] | group = group[[1]], n_groups = n_groups, chain = chain1"

@model function model2(y, group, n_groups)
σ ~ truncated(Cauchy(0, 1), 0, Inf)
α ~ filldist(Normal(0, 10), n_groups)
for i in 1:length(y)
y[i] ~ Normal(α[group[i]], σ)
end
end

chain2 = sample(model2(y, group, n_groups), NUTS(0.65), 2_000; save_state=true)
logprob"y = y[[1]] | group = group[[1]], n_groups = n_groups, chain = chain2"
end
end
45 changes: 45 additions & 0 deletions test/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -505,4 +505,49 @@ include(dir*"/test/test_utils/AllUtils.jl")
@test vi.metadata.w.gids[1] == Set([hmc.selector])
@test vi.metadata.u.gids[1] == Set([hmc.selector])
end

@testset "setval!" begin
@model function testmodel(x)
n = length(x)
s ~ truncated(Normal(), 0, Inf)
m ~ MvNormal(n, 1.0)
x ~ MvNormal(m, s)
end

x = randn(5)
model = testmodel(x)

# UntypedVarInfo
vi = VarInfo()
model(vi, SampleFromPrior())

vicopy = deepcopy(vi)
DynamicPPL.setval!(vicopy, (m = zeros(5),))
@test vicopy[@varname(m)] == zeros(5)
@test vicopy[@varname(s)] == vi[@varname(s)]

DynamicPPL.setval!(vicopy, (; (Symbol("m[$i]") => i for i in (1, 3, 5, 4, 2))...))
@test vicopy[@varname(m)] == 1:5
@test vicopy[@varname(s)] == vi[@varname(s)]

DynamicPPL.setval!(vicopy, (s = 42,))
@test vicopy[@varname(m)] == 1:5
@test vicopy[@varname(s)] == 42

# TypedVarInfo
vi = VarInfo(model)

vicopy = deepcopy(vi)
DynamicPPL.setval!(vicopy, (m = zeros(5),))
@test vicopy[@varname(m)] == zeros(5)
@test vicopy[@varname(s)] == vi[@varname(s)]

DynamicPPL.setval!(vicopy, (; (Symbol("m[$i]") => i for i in (1, 3, 5, 4, 2))...))
@test vicopy[@varname(m)] == 1:5
@test vicopy[@varname(s)] == vi[@varname(s)]

DynamicPPL.setval!(vicopy, (s = 42,))
@test vicopy[@varname(m)] == 1:5
@test vicopy[@varname(s)] == 42
end
end

0 comments on commit 63543dd

Please sign in to comment.