
Type instability with Dirichlet distribution #1276

Closed
ElOceanografo opened this issue May 12, 2020 · 7 comments

@ElOceanografo (Contributor)

Drawing a variable from a Dirichlet distribution introduces a type instability, slowing down sampling dramatically (around 25x on my laptop):

using Turing
@model MarginalizedGMM(x, K, ::Type{T}=Vector{Float64}) where {T} = begin
    N = length(x)
    μ = T(undef, K)
    σ = T(undef, K)
    for i in 1:K
        μ[i] ~ Normal(0, 5)
        σ[i] ~ Gamma()
    end
    w ~ Dirichlet(K, 1.0)
    # w = T([0.75, 0.25]) Way faster with this line instead of ↑
    for i in 1:N
      x[i] ~ Distributions.UnivariateGMM(μ, σ, Categorical(w))
    end
    return (μ::T, σ::T, w::T)
end

x = [randn(150) .- 2; randn(50) .+ 2]
gmm = MarginalizedGMM(x, 2)
varinfo = Turing.VarInfo(gmm)
spl = Turing.SampleFromPrior()
@code_warntype gmm.f(varinfo, spl, Turing.DefaultContext(), gmm)
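
For context (not from the thread): the kind of slowdown reported here can be reproduced in isolation with plain Julia, no Turing involved. A minimal sketch where the only difference is whether the container's element type is concrete:

```julia
# Illustration: identical loop, but the Vector{Any} forces dynamic
# dispatch and boxing on every iteration.
function sum_loop(xs)
    s = 0.0
    for x in xs
        s += x
    end
    return s
end

stable_data   = rand(10^6)                  # Vector{Float64}: concrete eltype
unstable_data = Any[x for x in stable_data] # Vector{Any}: abstract eltype

@time sum_loop(stable_data)
@time sum_loop(unstable_data)  # typically an order of magnitude slower
```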
@devmotion (Member)

An (additional?) problem might be the type instability of UnivariateGMM, see JuliaStats/Distributions.jl#1123.

@ElOceanografo (Contributor, Author)

Just came across this issue in a different model. I think the UnivariateGMM may have been an additional problem, but it looks like the Dirichlet distribution is still type-unstable:

using Turing, Random

@model function DirichletModel(x)
    w ~ Dirichlet(ones(5))
    x ~ Categorical(w)
end

dm = DirichletModel(4)
varinfo = Turing.VarInfo(dm)
spl = Turing.SampleFromPrior()
@code_warntype(dm.f(Random.GLOBAL_RNG, dm, varinfo, spl, Turing.DefaultContext()))

@bratslavia

Ah! I just got bit by this, trying to do the exact same thing (a marginalized mixture model).

@torfjelde (Member)

It seems the cause of the issue is Bijectors.invlink for the Dirichlet distribution. Super-confused about why though:

julia> d = Dirichlet(ones(5));

julia> x = rand(d);

julia> @code_warntype Bijectors.invlink(d, x)
Variables
  #self#::Core.Const(Bijectors.invlink)
  d::Dirichlet{Float64, Vector{Float64}, Float64}
  y::Vector{Float64}

Body::Any
1 ─ %1 = (#self#)(d, y, true)::Any
└──      return %1

But the following, which is the exact implementation of invlink, is type-stable:

julia> @code_warntype inv(Bijectors.SimplexBijector{true}())(x)
Variables
  ib::Core.Const(Inverse{Bijectors.SimplexBijector{1, true}, 1}(Bijectors.SimplexBijector{1, true}()))
  y::Vector{Float64}

Body::Vector{Float64}
1 ─ %1 = Base.getproperty(ib, :orig)::Core.Const(Bijectors.SimplexBijector{1, true}())
│   %2 = Bijectors._simplex_inv_bijector(y, %1)::Vector{Float64}
└──      return %2

Also weird, link works just fine:

julia> @code_warntype Bijectors.link(d, x)
Variables
  #self#::Core.Const(Bijectors.link)
  d::Dirichlet{Float64, Vector{Float64}, Float64}
  x::Vector{Float64}

Body::Vector{Float64}
1 ─ %1 = (#self#)(d, x, true)::Vector{Float64}
└──      return %1
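
As a side note (not from the thread): `Test.@inferred` gives a pass/fail version of the same check, since it throws when the inferred return type is not concrete. A sketch assuming the same `d` and `x` as above:

```julia
using Distributions, Bijectors, Test

d = Dirichlet(ones(5))
x = rand(d)

@inferred Bijectors.link(d, x)      # passes: inferred as Vector{Float64}
# @inferred Bijectors.invlink(d, x) # on affected Bijectors versions this
#                                   # throws, since the result infers as Any
```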

@torfjelde (Member)

torfjelde commented Apr 4, 2021

Actually, found the offending code: https://github.com/TuringLang/Bijectors.jl/blob/master/src/bijectors/simplex.jl#L6-L12

Currently we have the following:

julia> @code_warntype Bijectors.invlink(d, x, true)
Variables
  #self#::Core.Const(Bijectors.invlink)
  d::Dirichlet{Float64, Vector{Float64}, Float64}
  y::Vector{Float64}
  proj::Bool

Body::Any
1 ─ %1 = Core.apply_type(Main.SimplexBijector, proj)::Type
│   %2 = (%1)()::Any
│   %3 = Main.inv(%2)::Any
│   %4 = (%3)(y)::Any
└──      return %4
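
The `Core.apply_type` on a runtime `Bool` is the classic pattern behind this kind of instability: the return type depends on a value the compiler cannot see at inference time. A minimal, hypothetical stand-in (`Toy` is not from Bijectors):

```julia
struct Toy{P} end  # stand-in for SimplexBijector{proj}

from_runtime(p::Bool) = Toy{p}()         # infers as Toy{_A} where _A: abstract
from_val(::Val{p}) where {p} = Toy{p}()  # Toy{true} / Toy{false}: concrete

# Calls on from_runtime's result are dynamically dispatched; lifting the flag
# into the type domain (via Val, or hardcoding it) makes inference succeed.
```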

EDIT: Eeeeh, I ran the wrong method below 🙃 The above is indeed the reason why, but the below isn't a fix :)

If I implement

function Bijectors.invlink(
    d::Dirichlet,
    y::AbstractVecOrMat{<:Real},
    proj::Bool = true
)
    # Hardcoded the dimensionality to 1, thus circumventing
    # the function linked above.
    return inv(SimplexBijector{1, proj}())(y)
end

and we get

julia> @code_warntype Bijectors.link(d, x)
Variables
  #self#::Core.Const(Bijectors.link)
  d::Dirichlet{Float64, Vector{Float64}, Float64}
  x::Vector{Float64}

Body::Vector{Float64}
1 ─ %1 = (#self#)(d, x, true)::Vector{Float64}
└──      return %1

as wanted.

@devmotion (Member)

Alternatively, one could just dispatch on N and replace the function with the fallback and implementations for SimplexBijector{true} and SimplexBijector{false}.
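
A rough sketch of that alternative (hypothetical helper names, assuming the `SimplexBijector{1, proj}` parameters from the snippets above): one method per value of the flag, so no type is ever constructed from a runtime `Bool`:

```julia
using Distributions, Bijectors

# Hypothetical sketch only; not the actual patch (that landed in Bijectors 0.9).
_invlink_impl(y, ::Val{true})  = inv(Bijectors.SimplexBijector{1, true}())(y)
_invlink_impl(y, ::Val{false}) = inv(Bijectors.SimplexBijector{1, false}())(y)

# The branch picks one of two type-stable methods, so the whole call infers.
my_invlink(d::Dirichlet, y, proj::Bool = true) =
    proj ? _invlink_impl(y, Val(true)) : _invlink_impl(y, Val(false))
```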

@devmotion (Member)

Fixed in Bijectors 0.9 and Turing 0.15.13 which should be available soon (JuliaRegistries/General#33714).
