"""
IS()
Importance sampling algorithm.
Usage:
```julia
IS()
```
Example:

```julia
# Define a simple Normal model with unknown mean and variance.
@model function gdemo(x)
    s² ~ InverseGamma(2, 3)
    m ~ Normal(0, sqrt(s²))
    x[1] ~ Normal(m, sqrt(s²))
    x[2] ~ Normal(m, sqrt(s²))
    return s², m
end

sample(gdemo([1.5, 2]), IS(), 1000)
```
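
The estimated log evidence is attached to the returned chain. A minimal sketch,
assuming the default MCMCChains backend exposes the estimate as a `logevidence`
field:

```julia
chn = sample(gdemo([1.5, 2]), IS(), 1000)
chn.logevidence  # estimate of the log marginal likelihood log p(x)
```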
"""
struct IS{space} <: InferenceAlgorithm end
IS() = IS{()}()
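# Use the IS sampler itself to generate initial values, so they are drawn from
# the prior via `assume` below.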
DynamicPPL.initialsampler(sampler::Sampler{<:IS}) = sampler
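# The initial VarInfo has already been populated by sampling from the prior
# (see `initialsampler` above), so the first step just wraps it in a Transition.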
function DynamicPPL.initialstep(
    rng::AbstractRNG,
    model::Model,
    spl::Sampler{<:IS},
    vi::AbstractVarInfo;
    kwargs...
)
    return Transition(vi), nothing
end
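# Each subsequent step re-runs the model with a fresh VarInfo, drawing a new
# sample from the prior; the log-likelihood accumulated by `observe` acts as
# the sample's importance weight.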
function AbstractMCMC.step(
    rng::Random.AbstractRNG,
    model::Model,
    spl::Sampler{<:IS},
    ::Nothing;
    kwargs...
)
    vi = VarInfo(rng, model, spl)
    return Transition(vi), nothing
end
# Estimate the log evidence (log marginal likelihood) from the importance weights:
# log p̂(x) = log((1/N) ∑ᵢ exp(lpᵢ)) = logsumexp(lp) - log(N).
function getlogevidence(samples::Vector{<:Transition}, ::Sampler{<:IS}, state)
    return logsumexp(map(x -> x.lp, samples)) - log(length(samples))
end
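# Latent variables are drawn from the prior (or reused if already present) and
# contribute zero log-probability, so the prior serves as the proposal distribution.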
function DynamicPPL.assume(rng, spl::Sampler{<:IS}, dist::Distribution, vn::VarName, vi)
    if haskey(vi, vn)
        r = vi[vn]
    else
        r = rand(rng, dist)
        vi = push!!(vi, vn, r, dist, spl)
    end
    return r, 0, vi
end
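# Observations contribute their log-likelihood, which becomes the importance
# weight of the sample.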
function DynamicPPL.observe(spl::Sampler{<:IS}, dist::Distribution, value, vi)
    return logpdf(dist, value), vi
end