-
Notifications
You must be signed in to change notification settings - Fork 419
/
lkj.jl
255 lines (216 loc) · 7.63 KB
/
lkj.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
"""
    LKJ(d, η)

```julia
d::Int   dimension
η::Real  positive shape
```

The [LKJ](https://doi.org/10.1016/j.jmva.2009.04.008) distribution is a distribution over
``d\\times d`` real correlation matrices (positive-definite matrices with ones on the diagonal).
If ``\\mathbf{R}\\sim \\textrm{LKJ}_{d}(\\eta)``, then its probability density function is

```math
f(\\mathbf{R};\\eta) = \\left[\\prod_{k=1}^{d-1}\\pi^{\\frac{k}{2}}
\\frac{\\Gamma\\left(\\eta+\\frac{d-1-k}{2}\\right)}{\\Gamma\\left(\\eta+\\frac{d-1}{2}\\right)}\\right]^{-1}
|\\mathbf{R}|^{\\eta-1}.
```

If ``\\eta = 1``, then the LKJ distribution is uniform over
[the space of correlation matrices](https://www.jstor.org/stable/2684832).

The constructor checks that `d > 0` and `η > 0`; pass `check_args = false` to skip
these checks.
"""
struct LKJ{T <: Real, D <: Integer} <: ContinuousMatrixDistribution
# Matrix dimension: the distribution is over `d × d` matrices.
d::D
# Shape parameter η.
η::T
# Cached log normalizing constant; the outer constructor fills it with
# `lkj_logc0(d, η)` so that it is not recomputed on every density evaluation.
logc0::T
end
# -----------------------------------------------------------------------------
# Constructors
# -----------------------------------------------------------------------------
# Outer constructor. Validates the arguments (unless `check_args` is false),
# computes the log normalizing constant once, and stores η and that constant
# under their common promoted element type.
function LKJ(d::Integer, η::Real; check_args = true)
    if check_args
        if !(d > 0)
            throw(ArgumentError("Matrix dimension must be positive."))
        end
        if !(η > 0)
            throw(ArgumentError("Shape parameter must be positive."))
        end
    end
    logconst = lkj_logc0(d, η)
    P = Base.promote_eltype(η, logconst)
    return LKJ{P, typeof(d)}(d, P(η), P(logconst))
end
# -----------------------------------------------------------------------------
# REPL display
# -----------------------------------------------------------------------------
# Print the distribution through `show_multline`, listing the dimension and
# the shape parameter as (name, value) pairs.
function show(io::IO, d::LKJ)
    fields = [(:d, d.d), (:η, d.η)]
    return show_multline(io, d, fields)
end
# -----------------------------------------------------------------------------
# Conversion
# -----------------------------------------------------------------------------
# Re-parameterize an existing LKJ distribution with element type `T`.
# The dimension keeps its integer type; η and the cached constant are converted.
function convert(::Type{LKJ{T}}, d::LKJ) where {T <: Real}
    dimension = d.d
    shape = T(d.η)
    logconst = T(d.logc0)
    return LKJ{T, typeof(dimension)}(dimension, shape, logconst)
end
# Build an `LKJ{T}` straight from raw field values. η and logc0 are converted
# to `T`; the inner constructor is called directly, so no argument checks run
# and `logc0` is taken as given.
function convert(::Type{LKJ{T}}, d::Integer, η, logc0) where {T <: Real}
    shape = T(η)
    logconst = T(logc0)
    return LKJ{T, typeof(d)}(d, shape, logconst)
end
# -----------------------------------------------------------------------------
# Properties
# -----------------------------------------------------------------------------
# Dimension of the square matrices the distribution is defined over.
function dim(d::LKJ)
    return d.d
end
# Shape of a sample: a square `dim(d) × dim(d)` matrix.
function size(d::LKJ)
    n = dim(d)
    return (n, n)
end
# Rank reported for the distribution: equal to the matrix dimension.
function rank(d::LKJ)
    return dim(d)
end
# `R` is in the support when it is real, has the distribution's size, has an
# exactly-unit diagonal, and is positive definite. Checks run in that order
# and stop at the first failure.
function insupport(d::LKJ, R::AbstractMatrix)
    isreal(R) || return false
    size(R) == size(d) || return false
    isone(Diagonal(R)) || return false
    return isposdef(R)
end
# The mean is the identity matrix, with the distribution's parameter element type.
function mean(d::LKJ)
    n = dim(d)
    return Matrix{partype(d)}(I, n, n)
end
# Mode of the distribution. It coincides with the mean; unless `check_args`
# is false, an `ArgumentError` is thrown when η is not greater than 1.
function mode(d::LKJ; check_args = true)
    _, η = params(d)
    if check_args
        if !(η > 1)
            throw(ArgumentError("mode is defined only when η > 1."))
        end
    end
    return mean(d)
end
# Element-wise variance of an LKJ-distributed matrix.
#
# The diagonal is fixed, so its variance is zero. Every off-diagonal entry has
# the variance of the common marginal returned by `_marginal`.
#
# Returns a `d × d` matrix. For `d ≤ 1` there are no free elements and the
# result is all zeros.
function var(lkj::LKJ)
    d = dim(lkj)
    if !(d > 1)
        # Do not hard-code Float64 here: a bare `zeros(d, d)` would make the
        # return element type depend on the *value* of `d` for parameter types
        # wider than Float64 (e.g. BigFloat or AD number types). `_marginal`
        # mixes η with Float64 literals, so the general branch is presumably
        # at least Float64-wide; promote with Float64 to match it.
        T = promote_type(partype(lkj), Float64)
        return zeros(T, d, d)
    end
    σ² = var(_marginal(lkj))
    # ones - I: unit off-diagonal mask with a zero diagonal.
    return σ² * (ones(partype(lkj), d, d) - I)
end
# Parameters of the distribution as the tuple (dimension, shape).
function params(d::LKJ)
    return (d.d, d.η)
end
# Element type `T` of the real-valued parameters.
@inline function partype(d::LKJ{T}) where {T <: Real}
    return T
end
# -----------------------------------------------------------------------------
# Evaluation
# -----------------------------------------------------------------------------
# Log normalizing constant of the LKJ density: the negated log inverse constant.
# For `d ≤ 1` it is `zero(η)`. For η == 1 the closed-form uniform expressions
# are used, chosen by the parity of `d`; otherwise the general onion formula.
function lkj_logc0(d::Integer, η::Real)
    if d <= 1
        return zero(η)
    end
    if !isone(η)
        return -lkj_onion_loginvconst(d, η)
    end
    if iseven(d)
        return -lkj_onion_loginvconst_uniform_even(d)
    end
    return -lkj_onion_loginvconst_uniform_odd(d)
end
# Log of the unnormalized density at `R`: (η - 1) * log|R|.
function logkernel(d::LKJ, R::AbstractMatrix)
    exponent = d.η - 1
    return exponent * logdet(R)
end
# -----------------------------------------------------------------------------
# Sampling
# -----------------------------------------------------------------------------
# Draw one sample with the onion method and write it into `R` element-wise.
# Returns `R`.
function _rand!(rng::AbstractRNG, d::LKJ, R::AbstractMatrix)
    draw = _lkj_onion_sampler(d.d, d.η, rng)
    R .= draw
    return R
end
# Onion-method sampler, Section 3.2 in LKJ (2009 JMA).
# Grows a correlation matrix one row/column at a time, starting from the
# leading 2×2 block. Returns a `d × d` matrix with element type `typeof(η)`.
function _lkj_onion_sampler(d::Integer, η::Real, rng::AbstractRNG = Random.GLOBAL_RNG)
    # 1. Initialization: all ones, so the diagonal is already correct.
    R = ones(typeof(η), d, d)
    if !(d > 1)
        return R
    end
    β = η + 0.5d - 1
    b = rand(rng, Beta(β, β))
    # Map the Beta draw from (0, 1) onto (-1, 1) for the first correlation.
    R[1, 2] = 2b - 1
    R[2, 1] = R[1, 2]
    # 2. Append row/column k + 1 given the current leading k×k block.
    for k in 2:(d - 1)
        # (a)
        β -= 0.5
        # (b) squared length of the new column
        y = rand(rng, Beta(k / 2, β))
        # (c) uniformly distributed direction on the unit sphere
        direction = randn(rng, k)
        direction = direction / norm(direction)
        # (d) scale and map through the Cholesky factor of the current block
        w = sqrt(y) * direction
        chol_lower = cholesky(R[1:k, 1:k]).L
        z = chol_lower * w
        # (e) write the new column and its mirror row
        R[1:k, k + 1] = z
        R[k + 1, 1:k] = z'
    end
    # 3.
    return R
end
# -----------------------------------------------------------------------------
# The free elements of an LKJ matrix each have the same marginal distribution
# -----------------------------------------------------------------------------
# Common marginal distribution of each free (off-diagonal) element:
# a Beta(α, α) with α = η + d/2 - 1, shifted and scaled from (0, 1) to (-1, 1).
function _marginal(lkj::LKJ)
    α = lkj.η + 0.5 * lkj.d - 1
    base = Beta(α, α)
    return LocationScale(-1, 2, base)
end
# -----------------------------------------------------------------------------
# Test utils
# -----------------------------------------------------------------------------
# Test helper: after `check_univariate(d)`, return a point mass at one.
# This is the univariate counterpart, since the matrix's only entry is its
# unit diagonal.
function _univariate(d::LKJ)
    check_univariate(d)
    support = [one(d.η)]
    probs = [one(d.η)]
    return DiscreteNonParametric(support, probs)
end
# Test helper: random constructor arguments for an LKJ distribution.
# Requires square dims (`n == p`); the shape is |3z| for a standard normal
# draw z of type `elty`. Returns `(n, η)`.
function _rand_params(::Type{LKJ}, elty, n::Int, p::Int)
    if n != p
        throw(ArgumentError("dims must be equal for LKJ"))
    end
    shape = abs(3 * randn(elty))
    return (n, shape)
end
# -----------------------------------------------------------------------------
# Several redundant implementations of the reciprocal integrating constant.
# If f(R; n) = c₀ |R|ⁿ⁻¹, these give log(1 / c₀).
# Every integrating constant formula given in LKJ (2009 JMA) is an expression
# for 1 / c₀, even if they say that it is not.
# -----------------------------------------------------------------------------
# Log inverse normalizing constant, log(1 / c₀), for general η.
# Equation (17) in LKJ (2009 JMA)
function lkj_onion_loginvconst(d::Integer, η::Real)
    # Strict left fold so the summation order matches a plain accumulation loop.
    sumlogs = foldl(2:(d - 1); init = zero(η)) do acc, k
        acc + (0.5k * logπ + loggamma(η + 0.5(d - 1 - k)))
    end
    α = η + 0.5d - 1
    power_term = (2η + d - 3) * logtwo
    beta_term = logbeta(α, α)
    gamma_term = (d - 2) * loggamma(η + 0.5(d - 1))
    return power_term + beta_term + sumlogs - gamma_term
end
# Log inverse normalizing constant for the uniform case (η = 1), odd `d`.
# Theorem 5 in LKJ (2009 JMA)
function lkj_onion_loginvconst_uniform_odd(d::Integer)
    half = div(d - 1, 2)
    sumlogs = foldl((acc, k) -> acc + loggamma(2k), 1:half; init = 0.0)
    pi_term = 0.25(d^2 - 1) * logπ
    two_term = 0.25(d - 1)^2 * logtwo
    gamma_term = (d - 1) * loggamma(0.5(d + 1))
    return pi_term + sumlogs - two_term - gamma_term
end
# Log inverse normalizing constant for the uniform case (η = 1), even `d`.
# Theorem 5 in LKJ (2009 JMA)
function lkj_onion_loginvconst_uniform_even(d::Integer)
    half = div(d - 2, 2)
    sumlogs = foldl((acc, k) -> acc + loggamma(2k), 1:half; init = 0.0)
    pi_term = 0.25d * (d - 2) * logπ
    two_term = 0.25(3d^2 - 4d) * logtwo
    gamma_term = d * loggamma(0.5d)
    tail_term = (d - 1) * loggamma(d)
    return pi_term + two_term + gamma_term + sumlogs - tail_term
end
# Log inverse normalizing constant via the vine representation, general η.
# Equation (16) in LKJ (2009 JMA)
function lkj_vine_loginvconst(d::Integer, η::Real)
    expsum = zero(η)     # accumulated exponent of 2
    betasum = zero(η)    # accumulated weighted log-Beta terms
    for k in 1:(d - 1)
        weight = d - k
        α = η + 0.5(d - k - 1)
        expsum += (2η - 2 + d - k) * weight
        betasum += weight * logbeta(α, α)
    end
    return expsum * logtwo + betasum
end
# Log inverse normalizing constant via the vine representation, uniform case.
# Equation after (16) in LKJ (2009 JMA)
function lkj_vine_loginvconst_uniform(d::Integer)
    expsum = 0.0     # accumulated exponent of 2
    betasum = 0.0    # accumulated weighted log-Beta terms
    for k in 1:(d - 1)
        α = (k + 1) / 2
        expsum += k^2
        betasum += k * logbeta(α, α)
    end
    return expsum * logtwo + betasum
end
# Alternative expression for the log inverse normalizing constant.
# Third line in first proof of Section 3.3 in LKJ (2009 JMA)
function lkj_loginvconst_alt(d::Integer, η::Real)
    # Strict left fold: identical accumulation order to a running-sum loop.
    return foldl(1:(d - 1); init = zero(η)) do acc, k
        acc + (0.5k * logπ + loggamma(η + 0.5(d - 1 - k)) - loggamma(η + 0.5(d - 1)))
    end
end
# Log volume of the set of `n × n` correlation matrices.
# https://doi.org/10.4169/amer.math.monthly.123.9.909
function corr_logvolume(n::Integer)
    # Strict left fold: identical accumulation order to a running-sum loop.
    return foldl(1:(n - 1); init = 0.0) do acc, k
        acc + (0.5k * logπ + k * loggamma((k + 1) / 2) - k * loggamma((k + 2) / 2))
    end
end