Add LKJ Matrix Distribution #108

Closed
joshualeond opened this issue May 28, 2020 · 17 comments · Fixed by #125
Comments

@joshualeond

I'm currently trying to build a model in Turing that utilizes the newly added LKJ distribution from Distributions.jl. I ran into issues when I attempted to sample the model:

ERROR: MethodError: no method matching bijector(::LKJ{Float64,Int64})
Closest candidates are:
  bijector(::KSOneSided) at /Users/joshualeond/.julia/packages/Bijectors/bHaf6/src/transformed_distribution.jl:58
  bijector(::Product{Discrete,T,V} where V<:AbstractArray{T,1} where T<:Distribution{Univariate,Discrete}) at /Users/joshualeond/.julia/packages/Bijectors/bHaf6/src/transformed_distribution.jl:39
  bijector(::Product{Continuous,T,Tdists} where Tdists<:(FillArrays.Fill{T,1,Axes} where Axes) where T<:Distribution{Univariate,Continuous}) at /Users/joshualeond/.julia/packages/Bijectors/bHaf6/src/compat/distributionsad.jl:16

It appears that the LKJ distribution may need to be added here in Bijectors:

const PDMatDistribution = Union{MatrixBeta, InverseWishart, Wishart}

However, after I simply added LKJ to this line I hit the following error:

ERROR: MethodError: no method matching getlogp(::LKJ{Float64,Int64}, ::Cholesky{Float64,Array{Float64,2}}, ::Array{Float64,2})
Closest candidates are:
  getlogp(::MatrixBeta, ::Any, ::Any) at C:\Users\dunjos0\.julia\dev\Bijectors\src\Bijectors.jl:231
  getlogp(::Wishart, ::Any, ::Any) at C:\Users\dunjos0\.julia\dev\Bijectors\src\Bijectors.jl:236
  getlogp(::InverseWishart, ::Any, ::Any) at C:\Users\dunjos0\.julia\dev\Bijectors\src\Bijectors.jl:239

So it looks like a new getlogp method needs to be added to Bijectors for LKJ. I see that these getlogp methods closely follow the logkernel methods in Distributions.jl, but I'm unsure what needs to be adjusted to make sure it works nicely with Turing.
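For illustration, here's a guess at what such a method could look like, based on the logkernel of LKJ in Distributions.jl (unverified on my end, so treat it as a sketch only):

using Bijectors, Distributions, LinearAlgebra

# Guess only (field name and exact signature unchecked): Distributions' logkernel
# for LKJ is (η - 1) * logdet(X), and logdet of a Cholesky factorization equals
# logdet of the original matrix.
Bijectors.getlogp(d::LKJ, Xcf::Cholesky, X::AbstractMatrix) = (d.η - 1) * logdet(Xcf)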

@mohamed82008
Member

Just defining Bijectors.bijector(::LKJ) = PDBijector() should be enough I think.
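For example, a quick sanity check (just a sketch to confirm it resolves the MethodError; whether PDBijector is really the right transform for correlation matrices is a separate question):

using Bijectors, Distributions

Bijectors.bijector(::LKJ) = Bijectors.PDBijector()

d = LKJ(3, 1.0)
x = rand(d)        # a 3x3 correlation matrix
b = bijector(d)    # no longer throws a MethodError
y = b(x)           # unconstrained representation
x2 = inv(b)(y)     # maps back to a positive-definite matrix, approximately recovering x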

@joshualeond
Author

Thanks @mohamed82008, I defined Bijectors.bijector(::LKJ) as you suggested, and that gets me past the original ERROR: MethodError: no method matching bijector(::LKJ{Float64,Int64}). I think I may be up against a user error now, though. Here's a short reproducible example using the LKJ after the bijector was defined:

using Turing, Bijectors, LinearAlgebra, Random
Random.seed!(666)

# generate data
sigma = [1,2,3]
Omega = [1 0.3 0.2;
        0.3 1 0.1;
        0.2 0.1 1]

Sigma = diagm(sigma) * Omega * diagm(sigma)
N = 100
J = 3
y = rand(MvNormal(zeros(J), Sigma), N)'

# model
@model correlation(J, N, y, Zero) = begin
    sigma ~ filldist(truncated(Cauchy(0., 5.), 0., Inf), J) # prior on the standard deviations
    Omega ~ LKJ(J, 1) # LKJ prior on the correlation matrix
    # covariance matrix
    Sigma = diagm(sigma) * Omega * diagm(sigma)

    for i in 1:N
        y[i,:] ~ MvNormal(Zero, Sigma) # sampling distribution of the observations
    end
end

Bijectors.bijector(::LKJ) = PDBijector()

# attempt to recover parameters
chain = sample(correlation(J, N, y, zeros(J)), NUTS(), 1000)

And the error:

ERROR: PosDefException: matrix is not Hermitian; Cholesky factorization failed.

Perhaps I misspecified something here?

@torfjelde
Member

torfjelde commented May 30, 2020

It seems like the issue is with your Sigma = diagm(sigma) * Omega * diagm(sigma) line.

julia> # model
       @model correlation(J, N, y, Zero) = begin
           sigma ~ filldist(truncated(Cauchy(0., 5.), 0., Inf), J) # prior on the standard deviations
           Omega ~ LKJ(J, 1) # LKJ prior on the correlation matrix
           @info isposdef(Omega)
           # covariance matrix
           Sigma = diagm(sigma) * Omega * diagm(sigma)
           @info Sigma
           @info isposdef(Sigma)

           for i in 1:N
               y[i,:] ~ MvNormal(Zero, Sigma) # sampling distribution of the observations
           end
       end
DynamicPPL.ModelGen{var"###generator#300",(:J, :N, :y, :Zero),(),Tuple{}}(##generator#300, NamedTuple())

julia> 

julia> m = correlation(J, N, y, zeros(J));

julia> m()
[ Info: true
[ Info: [2.43118945431723 0.8857326262243734 -0.38161465118558274; 0.8857326262243734 0.496164416668245 -0.2301838107162483; -0.38161465118558274 -0.2301838107162483 0.28572768840562973]
[ Info: true

julia> m()
[ Info: true
[ Info: [289.05898695571943 -3.900716770274498 -45.83646433466277; -3.900716770274498 0.3260931391923026 2.375071240046279; -45.83646433466277 2.3750712400462786 74.5652927508229]
[ Info: false
ERROR: PosDefException: matrix is not Hermitian; Cholesky factorization failed.

EDIT: haha, nevermind I'm stupid 🙃 The above is relevant, but I deleted the parts of my comment that were just me brain-farting like crazy.

@joshualeond
Author

Thanks for checking out my example @torfjelde! So the diagm(sigma) * Omega * diagm(sigma) is my attempt at reproducing the quad_form_diag available in Stan. I've seen the LKJ used as a prior on the correlation matrix in Stan like the following:

data {
  int<lower=1> N; // number of observations
  int<lower=1> J; // dimension of observations
  vector[J] y[N]; // observations
  vector[J] Zero; // a vector of Zeros (fixed means of observations)
}
parameters {
  corr_matrix[J] Omega; 
  vector<lower=0>[J] sigma; 
}
transformed parameters {
  cov_matrix[J] Sigma; 
  Sigma <- quad_form_diag(Omega, sigma); 
}
model {
  y ~ multi_normal(Zero,Sigma); // sampling distribution of the observations
  sigma ~ cauchy(0, 5); // prior on the standard deviations
  Omega ~ lkj_corr(1); // LKJ prior on the correlation matrix 
}

The Stan docs make it seem pretty straightforward, but perhaps there's more going on than I realize:

matrix quad_form_diag(matrix m, vector v)
The quadratic form using the column vector v as a diagonal matrix, i.e., diag_matrix(v) * m * diag_matrix(v).
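In Julia the equivalent is just a one-liner (quad_form_diag here is an illustrative name, not an existing Julia function):

using LinearAlgebra

# Julia analogue of Stan's quad_form_diag: scale a correlation matrix by a
# vector of standard deviations to obtain a covariance matrix.
quad_form_diag(m, v) = Diagonal(v) * m * Diagonal(v)

Sigma = quad_form_diag(Omega, sigma)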

@devmotion
Member

Completely unrelated, but I think you should use Diagonal instead of diagm, since the former does not actually allocate a matrix whereas the latter does. A simple benchmark:

julia> using BenchmarkTools, LinearAlgebra

julia> f(v) = Diagonal(v)

julia> g(v) = diagm(v)

julia> @btime f($(rand(100)));
  4.866 ns (1 allocation: 16 bytes)

julia> @btime g($(rand(100)));
  5.854 μs (3 allocations: 78.23 KiB)

@devmotion
Member

These issues with positive definiteness are nasty: even if the matrix is guaranteed to be positive (semi-)definite mathematically, numerical issues can easily make it fail to be positive (semi-)definite numerically (I've run into this problem multiple times when parameterizing MvNormal with estimated covariance matrices). Sometimes wrapping the matrix in Symmetric helps, but lately only https://github.com/timholy/PositiveFactorizations.jl could fix my numerical issues. However, I've never tried to use it together with Turing, so I'm not sure if AD works with that package.
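Roughly, the usage would be (a sketch based on that package's README; as said, untested with Turing/AD):

using LinearAlgebra, PositiveFactorizations

A = [1.0 1.0; 1.0 1.0 - 1e-12]  # numerically indefinite even though it "should" be PSD
# cholesky(A) throws a PosDefException here, whereas
F = cholesky(Positive, A)       # factorizes a nearby positive-definite matrix instead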

@torfjelde
Member

Sometimes wrapping the matrix into Symmetric helps

@devmotion to the rescue!

julia> # model
       @model correlation(J, N, y, Zero) = begin
           sigma ~ filldist(truncated(Cauchy(0., 5.), 0., Inf), J) # prior on the standard deviations
           Omega ~ LKJ(J, 1) # LKJ prior on the correlation matrix
           @info sigma
           @info Omega
           @info isposdef(Omega)
           # covariance matrix
           Sigma = Symmetric(Diagonal(sigma) * Omega * Diagonal(sigma))
           @info Sigma
           @info isposdef(Sigma)

           for i in 1:N
               y[i,:] ~ MvNormal(Zero, Sigma) # sampling distribution of the observations
           end
       end
DynamicPPL.ModelGen{var"###generator#348",(:J, :N, :y, :Zero),(),Tuple{}}(##generator#348, NamedTuple())

julia> m = correlation(J, N, y, zeros(J));

julia> m()
[ Info: [6.554761110166019, 2.34510326436457, 0.6888832327192145]
[ Info: [1.0 0.47103043693813107 0.5681366777033788; 0.47103043693813107 1.0 -0.41754820600773; 0.5681366777033788 -0.41754820600773 1.0]
[ Info: true
[ Info: [42.96489321134486 7.24048754385414 2.5654012966083335; 7.24048754385414 5.499509320533363 -0.674550094605337; 2.5654012966083335 -0.674550094605337 0.4745601083216754]
[ Info: true

julia> m()
[ Info: [39.64435704629228, 0.9699293307273142, 3.1448223201647334]
[ Info: [1.0 0.051875690114511874 -0.052259251243831094; 0.051875690114511874 1.0 0.8063928903464262; -0.052259251243831094 0.8063928903464262 1.0]
[ Info: true
[ Info: [1571.675045613904 1.9947356925964466 -6.515393871749325; 1.9947356925964466 0.9407629066051357 2.4597042749565188; -6.515393871749325 2.4597042749565188 9.889907425406298]
[ Info: true

julia> m()
[ Info: [0.8388798227529882, 3.3189729374771266, 8.826661070228887]
[ Info: [1.0 0.7104582898219662 -0.4776195773069013; 0.7104582898219662 1.0 -0.07252323132333671; -0.4776195773069013 -0.07252323132333671 1.0]
[ Info: true
[ Info: [0.703719357022085 1.978071774380738 -3.536537920990547; 1.978071774380738 11.015581359705546 -2.124600640530144; -3.536537920990547 -2.124600640530144 77.90994564869416]
[ Info: true

julia> m()
[ Info: [2.052493460163989, 18.75614790558376, 3.1610610751936297]
[ Info: [1.0 -0.2180845855069199 -0.42228217488217845; -0.2180845855069199 1.0 0.09044611498906638; -0.42228217488217845 0.09044611498906638 1.0]
[ Info: true
[ Info: [4.212729404015944 -8.395574136610355 -2.73979089842532; -8.395574136610355 351.793084256134 5.362489474229928; -2.73979089842532 5.362489474229928 9.992307121104306]
[ Info: true

julia> m()
[ Info: [671.1873925742121, 0.3471925977453725, 12.611536505760812]
[ Info: [1.0 0.6777757171051204 -0.10065195388769885; 0.6777757171051204 1.0 0.6147688385377347; -0.10065195388769885 0.6147688385377347 1.0]
[ Info: true
[ Info: [450492.51595056953 157.94295267110346 -851.9890272445987; 157.94295267110346 0.12054269992918003 2.6918465834085405; -851.9890272445987 2.6918465834085405 159.05085303613762]
[ Info: true

julia> m()
[ Info: [0.8079920565082861, 4.041301143186047, 38.36495641353685]
[ Info: [1.0 0.35915995608688434 -0.19698184713779998; 0.35915995608688434 1.0 0.5033179678024089; -0.19698184713779998 0.5033179678024089 1.0]
[ Info: true
[ Info: [0.6528511633804894 1.1727790914573786 -6.1061575530419185; 1.1727790914573786 16.332114929916848 78.03660324156078; -6.1061575530419185 78.03660324156078 1471.8698806125824]
[ Info: true

I'm assuming the reason this works is that there exists a more numerically stable method for symmetric matrices, and by wrapping the matrix in Symmetric you make sure to dispatch to the correct method :)

@devmotion
Member

Unfortunately, the simple (and slightly inefficient) reason for it is that in https://github.com/JuliaStats/PDMats.jl/blob/00804c3ca96a0839c03d25782a51028fe96fa725/src/pdmat.jl#L20 a new matrix is allocated in which the upper triangle is mirrored onto the lower one. Hence any numerical discrepancy between the two is gone afterwards. That's also the reason why it doesn't always fix the issues (in my experience).
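A tiny illustration of that mirroring:

julia> using LinearAlgebra

julia> A = [1.0 0.3; 0.3 + 1e-12 1.0];  # tiny asymmetry, e.g. from floating-point error

julia> issymmetric(A)
false

julia> Symmetric(A)[2, 1]  # the lower triangle is simply read off the upper one
0.3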

@devmotion
Member

Maybe you could avoid that by using

using PDMats, LinearAlgebra
...
_Sigma = Symmetric(Diagonal(sigma) * Omega * Diagonal(sigma))
Sigma = PDMat(_Sigma, cholesky(_Sigma))
....

since it seems there exists an implementation of cholesky for Symmetric which might be more efficient (see, e.g., https://github.com/JuliaLang/julia/blob/7301dc61bdeb5d66e94e15bdfcd4c54f7c90f068/stdlib/LinearAlgebra/src/cholesky.jl#L217-L221). I'm wondering why that is not the default in PDMats 🤔

@joshualeond
Author

Thanks for all the tips, I tried out your suggestion with the PDMat but ran into the following error:

ERROR: MethodError: no method matching PDMat{Float64,Symmetric{Float64,Array{Float64,2}}}(::Int64, ::Symmetric{Float64,Array{Float64,2}}, ::Cholesky{Float64,Array{Float64,2}})

If I removed the second argument in PDMat then it did sample with HMC but had some odd results:

using Turing, Distributions, LinearAlgebra, Random, Bijectors, PDMats
Bijectors.bijector(d::LKJ) = Bijectors.PDBijector()

Random.seed!(666)
# generate data
sigma = [1,2,3]
Omega = [1 0.3 0.2;
        0.3 1 0.1;
        0.2 0.1 1]

Sigma = Diagonal(sigma) * Omega * Diagonal(sigma)
N = 100
J = 3
y = rand(MvNormal(zeros(J), Sigma), N)'

# model
@model correlation(J, N, y, Zero) = begin
    sigma ~ filldist(truncated(Cauchy(0., 5.), 0., Inf), J) # prior on the standard deviations
    Omega ~ LKJ(J, 1) # LKJ prior on the correlation matrix

    _Sigma = Symmetric(Diagonal(sigma) * Omega * Diagonal(sigma))
    Sigma = PDMat(_Sigma)

    for i in 1:N
        y[i,:] ~ MvNormal(Zero, Sigma) # sampling distribution of the observations
    end
    return Sigma
end

chain = sample(correlation(J, N, y, zeros(J)), HMC(0.01, 5), 1000)
chain = sample(correlation(J, N, y, zeros(J)), NUTS(), 1000)
Summary Statistics
   parameters     mean       std  naive_se     mcse      ess   r_hat
  ───────────  ───────  ────────  ────────  ───────  ───────  ──────
  Omega[1, 1]  87.4067  183.9117    5.8158  50.4635   4.6011  1.2026
  Omega[1, 2]  14.2648   21.3900    0.6764   6.7411   4.1112  1.4768
  Omega[1, 3]   4.8683    9.4294    0.2982   2.8001   4.7358  1.2381
  Omega[2, 1]  14.2648   21.3900    0.6764   6.7411   4.1112  1.4768
  Omega[2, 2]  24.3924   35.9843    1.1379  11.4285   4.1922  1.4981
  Omega[2, 3]   0.8385    1.8374    0.0581   0.3172  23.5443  1.1229
  Omega[3, 1]   4.8683    9.4294    0.2982   2.8001   4.7358  1.2381
  Omega[3, 2]   0.8385    1.8374    0.0581   0.3172  23.5443  1.1229
  Omega[3, 3]  10.0522   13.8718    0.4387   4.4399   4.0161  1.5990
     sigma[1]   0.2820    0.2060    0.0065   0.0659   4.0161  1.8310
     sigma[2]   1.1579    0.8786    0.0278   0.2863   4.0161  2.4314
     sigma[3]   2.5682    2.3649    0.0748   0.7436   4.0161  1.8521

The sigmas aren't too far off but the correlation matrix Omega has some relatively large numbers on the diagonal that should be 1. Oddly enough, if I sample with NUTS I end up with the old error again:

ERROR: PosDefException: matrix is not Hermitian; Cholesky factorization failed.

@devmotion
Member

Thanks for all the tips, I tried out your suggestion with the PDMat but ran into the following error:

Ah, then that's probably the reason why PDMats doesn't use Symmetric directly 😄 So I guess just drop PDMats and pass the Symmetric matrix to MvNormal directly.

BTW, you should probably also remove Zero: if you don't pass a mean vector, MvNormal will automatically have a mean of zero (and use something more optimized than zeros(J)).
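For example:

using Distributions

Σ = [1.0 0.3; 0.3 1.0]
d = MvNormal(Σ)   # covariance-only constructor: the mean is implicitly zero
mean(d)           # equals [0.0, 0.0] without you ever allocating it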

@joshualeond
Author

Good point on the Zero vector; I've removed it along with the PDMat:

using Turing, Distributions, LinearAlgebra, Random, Bijectors
Bijectors.bijector(d::LKJ) = Bijectors.PDBijector()

Random.seed!(666)
# generate data
sigma = [1,2,3]
Omega = [1 0.3 0.2;
        0.3 1 0.1;
        0.2 0.1 1]

Sigma = Diagonal(sigma) * Omega * Diagonal(sigma)
N = 100
J = 3
y = rand(MvNormal(Sigma), N)'

# model
@model correlation(J, N, y) = begin
    sigma ~ filldist(truncated(Cauchy(0., 5.), 0., Inf), J) # prior on the standard deviations
    Omega ~ LKJ(J, 1) # LKJ prior on the correlation matrix

    Sigma = Symmetric(Diagonal(sigma) * Omega * Diagonal(sigma))

    for i in 1:N
        y[i,:] ~ MvNormal(Sigma) # sampling distribution of the observations
    end
    return Sigma
end

chain = sample(correlation(J, N, y), HMC(0.01, 5), 1000)

With HMC, similar results as before. Not quite recovering the original parameters:

Summary Statistics
   parameters     mean      std  naive_se     mcse     ess   r_hat
  ───────────  ───────  ───────  ────────  ───────  ──────  ──────
  Omega[1, 1]  14.0667  15.5562    0.4919   4.8148  4.0161  2.0984
  Omega[1, 2]   3.0884   3.3027    0.1044   1.0634  4.0161  1.9885
  Omega[1, 3]   1.2705   1.4212    0.0449   0.4322  4.0541  1.9204
  Omega[2, 1]   3.0884   3.3027    0.1044   1.0634  4.0161  1.9885
  Omega[2, 2]   7.4755   9.5501    0.3020   3.0134  4.1477  1.5860
  Omega[2, 3]   0.4702   0.6166    0.0195   0.1384  5.8176  1.3960
  Omega[3, 1]   1.2705   1.4212    0.0449   0.4322  4.0541  1.9204
  Omega[3, 2]   0.4702   0.6166    0.0195   0.1384  5.8176  1.3960
  Omega[3, 3]  25.0593  37.7232    1.1929  11.9310  4.0702  1.5460
     sigma[1]   0.6272   0.5028    0.0159   0.1633  4.0161  2.1600
     sigma[2]   1.5641   1.0518    0.0333   0.3419  4.0161  2.0743
     sigma[3]   2.0635   1.7036    0.0539   0.5489  4.0161  2.5594

With NUTS:

ERROR: PosDefException: matrix is not positive definite; Cholesky factorization failed.

But sometimes when I sample with NUTS I'm actually seeing VERY large numbers, like 1e155 large.

@joshualeond
Author

I had originally opened an issue on another repo, where @trappmartin ended up giving me some advice on this particular issue with the LKJ. I wanted to bring some of the information from that issue over to this one and close the original.

On that issue Martin ended up restructuring the model specification as follows:

@model correlation(J, N, y) = begin
    sigma ~ filldist(truncated(Cauchy(0., 5.), 0., Inf), J) # prior on the standard deviations
    Omega ~ LKJ(J, 1) # LKJ prior on the correlation matrix

    L = Diagonal(sigma) * Omega

    for i in 1:N
        y[i,:] ~ MvNormal(L*L') # sampling distribution of the observations
    end
    return L*L'
end

However, even with this I'm still seeing issues when sampling the posterior distribution. Here's a summary of samples drawn from the prior for this model:

sample(correlation(J, N, y), Prior(), 2000)
Summary Statistics
   parameters     mean       std  naive_se     mcse        ess   r_hat
  ───────────  ───────  ────────  ────────  ───────  ─────────  ──────
  Omega[1, 1]   1.0000    0.0000    0.0000   0.0000        NaN     NaN
  Omega[1, 2]  -0.0021    0.4940    0.0110   0.0138  1827.3121  1.0008
  Omega[1, 3]   0.0129    0.5078    0.0114   0.0121  2202.8359  0.9996
  Omega[2, 1]  -0.0021    0.4940    0.0110   0.0138  1827.3121  1.0008
  Omega[2, 2]   1.0000    0.0000    0.0000   0.0000        NaN     NaN
  Omega[2, 3]   0.0027    0.5036    0.0113   0.0109  2002.3309  0.9997
  Omega[3, 1]   0.0129    0.5078    0.0114   0.0121  2202.8359  0.9996
  Omega[3, 2]   0.0027    0.5036    0.0113   0.0109  2002.3309  0.9997
  Omega[3, 3]   1.0000    0.0000    0.0000   0.0000        NaN     NaN
     sigma[1]  30.3646  333.1480    7.4494   6.9035  2057.3155  0.9997
     sigma[2]  39.5240  864.4344   19.3293  18.8958  2018.9760  0.9999
     sigma[3]  36.1043  548.1510   12.2570  11.4607  2024.4883  0.9999

Quantiles
   parameters     2.5%    25.0%    50.0%    75.0%     97.5%
  ───────────  ───────  ───────  ───────  ───────  ────────
  Omega[1, 1]   1.0000   1.0000   1.0000   1.0000    1.0000
  Omega[1, 2]  -0.8740  -0.4035  -0.0009   0.4025    0.8551
  Omega[1, 3]  -0.8926  -0.4015   0.0441   0.4347    0.8704
  Omega[2, 1]  -0.8740  -0.4035  -0.0009   0.4025    0.8551
  Omega[2, 2]   1.0000   1.0000   1.0000   1.0000    1.0000
  Omega[2, 3]  -0.8760  -0.4089   0.0092   0.3951    0.8914
  Omega[3, 1]  -0.8926  -0.4015   0.0441   0.4347    0.8704
  Omega[3, 2]  -0.8760  -0.4089   0.0092   0.3951    0.8914
  Omega[3, 3]   1.0000   1.0000   1.0000   1.0000    1.0000
     sigma[1]   0.2368   2.2082   5.1726  12.5942  149.0269
     sigma[2]   0.2114   2.1148   4.9539  12.9399  117.8365
     sigma[3]   0.2138   1.9621   4.8745  11.4329  138.6144

So the prior samples look good, with 1s on the diagonal of the correlation matrix. However, after sampling with HMC we get results like those shown previously: low ess, high r_hat, and estimates that don't respect the LKJ prior:

sample(correlation(J, N, y), HMC(0.01, 5), 2000)
Summary Statistics
   parameters    mean     std  naive_se    mcse      ess   r_hat
  ───────────  ──────  ──────  ────────  ──────  ───────  ──────
  Omega[1, 1]  0.5430  0.3717    0.0083  0.0797   8.0321  1.7048
  Omega[1, 2]  0.0781  0.0555    0.0012  0.0106   8.3276  1.1574
  Omega[1, 3]  0.0718  0.0511    0.0011  0.0069  10.2790  1.2537
  Omega[2, 1]  0.0781  0.0555    0.0012  0.0106   8.3276  1.1574
  Omega[2, 2]  0.5147  0.2778    0.0062  0.0601  10.2855  0.9997
  Omega[2, 3]  0.0614  0.0479    0.0011  0.0049  41.6326  1.0054
  Omega[3, 1]  0.0718  0.0511    0.0011  0.0069  10.2790  1.2537
  Omega[3, 2]  0.0614  0.0479    0.0011  0.0049  41.6326  1.0054
  Omega[3, 3]  2.3791  0.9044    0.0202  0.1800  19.8614  0.9999
     sigma[1]  2.7542  1.6814    0.0376  0.3638   8.0321  1.9013
     sigma[2]  5.1374  2.6200    0.0586  0.5586  10.9404  1.0535
     sigma[3]  1.3219  0.4738    0.0106  0.0955  15.5824  1.0331

Quantiles
   parameters     2.5%   25.0%   50.0%   75.0%    97.5%
  ───────────  ───────  ──────  ──────  ──────  ───────
  Omega[1, 1]   0.1534  0.2320  0.4527  0.7055   1.4918
  Omega[1, 2]   0.0227  0.0439  0.0614  0.0898   0.2528
  Omega[1, 3]  -0.0006  0.0356  0.0615  0.1008   0.1993
  Omega[2, 1]   0.0227  0.0439  0.0614  0.0898   0.2528
  Omega[2, 2]   0.1765  0.2989  0.4679  0.6571   1.3056
  Omega[2, 3]  -0.0212  0.0290  0.0547  0.0876   0.1764
  Omega[3, 1]  -0.0006  0.0356  0.0615  0.1008   0.1993
  Omega[3, 2]  -0.0212  0.0290  0.0547  0.0876   0.1764
  Omega[3, 3]   1.1109  1.7512  2.2240  2.8160   4.7351
     sigma[1]   0.6662  1.4098  2.1513  4.2178   6.2623
     sigma[2]   1.5756  3.1137  4.3573  6.7164  11.0704
     sigma[3]   0.5932  0.9867  1.2341  1.5655   2.4542

Things seem to become much less stable when using NUTS as well. Like I mentioned before, if it finishes sampling I end up with very large explosive parameters.

I was hoping there was a simple solution for using the LKJ distribution with Turing, but according to Martin it sounds like the issue may be at a lower level:

I think it’s mostly an issue related to the constraints of LKJ. And an issue on Turing or Bijectors is the best place for it.

If there's anything you'd like me to test and report back to you then I'm definitely willing to help.

@jfb-h

jfb-h commented Jul 16, 2020

Just to add to this, I tried fitting a hierarchical linear model with a multivariate prior using StatsBase.cor2cov() (I think I just got lucky, and this is not generally safe with respect to the PosDefException):

@model hlm(
    y, X, ll,
    ::Type{T}=Vector{Float64},
    ::Type{S}=Matrix{Float64}) where {T, S} = begin

    N, K = size(X)
    L = maximum(ll)
    τ = T(undef, K)
    β = S(undef, L, K)
    Ω = S(undef, K, K)

    τ .~ Exponential(2)
    d = Diagonal(τ)
    Ω ~ LKJ(K, 3)
    Σ = cor2cov(Ω, τ)
    for l in 1:L
        β[l,:] ~ MvNormal(zeros(K), Σ)
    end

    μ = reshape(sum(β[ll,:] .* X, dims = 2), N)
    σ ~ Exponential(.5)
    y ~ MvNormal(μ, σ)
end

Similar to @joshualeond's case, sampling from the prior works fine, but NUTS() goes nuts :) with exploding numbers and a bunch of rejected proposals due to numerical errors.

@yiyuezhuo
Contributor

yiyuezhuo commented Jul 28, 2020

I can't see how Bijectors.bijector(::LKJ) = PDBijector() could even possibly solve the problem. Stan's corr_matrix[J] Omega; constrains parameters to the correlation-matrix space (where the diagonal elements are ones), while PDBijector constrains parameters to the general covariance-matrix (positive-definite) space. The problem is further hidden by the fact that the logpdf defined for LKJ in Distributions.jl doesn't check whether its input is actually a correlation matrix.

As you may notice, the Omega values tend to infinity:

[image: sampled Omega values growing without bound]

and sigma tends to 0:

[image: sampled sigma values shrinking toward zero]

This is a result of the LKJ density, which is proportional to det(Ω)^(η − 1):

[image: the LKJ density]

Obviously, Omega will be driven toward infinity to maximize the density, while sigma tends to 0 to neutralize the corresponding effect on Sigma.

The illusion that HMC works only means the (misguided) warm-up phase hasn't ended yet; if you run HMC or NUTS long enough, the strangely large Omega values or outright failures are inevitable.

So we need to define a bijector dedicated to correlation matrices, like corr_matrix in Stan. I will send a PR later.
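To sketch the idea (an illustration of the Stan-style transform only, not the actual PR):

using LinearAlgebra

# Sketch: map an unconstrained vector z of length K*(K-1)/2 through tanh'd
# canonical partial correlations to the lower-triangular Cholesky factor of a
# correlation matrix. Rows of L have unit norm, so L * L' has a unit diagonal.
function corr_cholesky(z::AbstractVector{T}, K::Int) where {T}
    @assert length(z) == K * (K - 1) ÷ 2
    L = zeros(T, K, K)
    L[1, 1] = one(T)
    idx = 1
    for i in 2:K
        remaining = one(T)  # 1 minus the sum of squares of the entries filled so far in row i
        for j in 1:(i - 1)
            L[i, j] = tanh(z[idx]) * sqrt(remaining)
            remaining -= L[i, j]^2
            idx += 1
        end
        L[i, i] = sqrt(remaining)
    end
    return L
end

L = corr_cholesky(randn(3), 3)  # K = 3 -> 3 free parameters
Omega = L * L'                  # a valid correlation matrix by construction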

@joshualeond
Author

@yiyuezhuo There's another Julia repo outside the Turing org that may have some helpful code for reference: https://github.com/tpapp/TransformVariables.jl

Specifically, the TransformVariables.CorrCholeskyFactor code.
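From its docs, the usage looks roughly like this (untested on my end):

using TransformVariables, LinearAlgebra

t = CorrCholeskyFactor(3)  # targets the Cholesky factor of a 3x3 correlation matrix
z = randn(dimension(t))    # 3 unconstrained parameters
U = transform(t, z)        # upper-triangular factor with unit-norm columns
Omega = U' * U             # a valid correlation matrix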

@yiyuezhuo
Contributor

That implementation looks fine, but it doesn't support Zygote (it's not mutation-free). Anyway, it's still inspiring.
