Type instability with Dirichlet distribution #1276
Comments
An (additional?) problem might be the type instability of |
Just came across this issue in a different model. I think the
using Turing, Random
@model function DirichletModel(x)
    w ~ Dirichlet(ones(5))
    x ~ Categorical(w)
end
dm = DirichletModel(4)
varinfo = Turing.VarInfo(dm)
spl = Turing.SampleFromPrior()
@code_warntype(dm.f(Random.GLOBAL_RNG, dm, varinfo, spl, Turing.DefaultContext())) |
Ah! I just got bit by this, trying to do the exact same thing (a marginalized mixture model). |
It seems the cause of the issue is
julia> d = Dirichlet(ones(5));
julia> x = rand(d);
julia> @code_warntype Bijectors.invlink(d, x)
Variables
#self#::Core.Const(Bijectors.invlink)
d::Dirichlet{Float64, Vector{Float64}, Float64}
y::Vector{Float64}
Body::Any
1 ─ %1 = (#self#)(d, y, true)::Any
└── return %1
But the following, which is the exact impl of
julia> @code_warntype inv(Bijectors.SimplexBijector{true}())(x)
Variables
ib::Core.Const(Inverse{Bijectors.SimplexBijector{1, true}, 1}(Bijectors.SimplexBijector{1, true}()))
y::Vector{Float64}
Body::Vector{Float64}
1 ─ %1 = Base.getproperty(ib, :orig)::Core.Const(Bijectors.SimplexBijector{1, true}())
│ %2 = Bijectors._simplex_inv_bijector(y, %1)::Vector{Float64}
└── return %2
Also weird,
julia> @code_warntype Bijectors.link(d, x)
Variables
#self#::Core.Const(Bijectors.link)
d::Dirichlet{Float64, Vector{Float64}, Float64}
x::Vector{Float64}
Body::Vector{Float64}
1 ─ %1 = (#self#)(d, x, true)::Vector{Float64}
└── return %1 |
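The `Body::Any` versus `Body::Vector{Float64}` contrast above can also be checked programmatically rather than by reading `@code_warntype` output. Here is a minimal sketch using `Test.@inferred` from the standard library. The functions `unstable_half` and `stable_half` are hypothetical stand-ins, not Bijectors.jl code; with Bijectors loaded, the same macro could presumably be wrapped around `Bijectors.invlink(d, x)`.

```julia
using Test

# Hypothetical stand-in: the return type depends on a runtime value,
# so inference can only conclude Union{Int, Float64}.
unstable_half(x::Int) = iseven(x) ? x ÷ 2 : x / 2

# Hypothetical stand-in: always returns a Float64.
stable_half(x::Int) = x / 2

# @inferred returns the value when the inferred and actual types agree.
stable_result = @inferred stable_half(4)

# @inferred throws when the inferred type is wider than the actual one.
threw = try
    @inferred unstable_half(4)
    false
catch
    true
end

println("stable result: ", stable_result, ", unstable call threw: ", threw)
```

A check like this could serve as a regression test once an instability has been fixed.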
Actually, found the offending code: https://github.com/TuringLang/Bijectors.jl/blob/master/src/bijectors/simplex.jl#L6-L12
Currently we have the following:
julia> @code_warntype Bijectors.invlink(d, x, true)
Variables
#self#::Core.Const(Bijectors.invlink)
d::Dirichlet{Float64, Vector{Float64}, Float64}
y::Vector{Float64}
proj::Bool
Body::Any
1 ─ %1 = Core.apply_type(Main.SimplexBijector, proj)::Type
│ %2 = (%1)()::Any
│ %3 = Main.inv(%2)::Any
│ %4 = (%3)(y)::Any
└── return %4
EDIT: Eeeeh I ran the wrong method below 🙃 But the above is indeed the reason why, but the below isn't a fix :)
If I implement
function Bijectors.invlink(
    d::Dirichlet,
    y::AbstractVecOrMat{<:Real},
    proj::Bool = true
)
    # Hardcoded the dimensionality to 1, thus circumventing
    # the function linked above.
    return inv(SimplexBijector{1, proj}())(y)
end
and we get
julia> @code_warntype Bijectors.link(d, x)
Variables
#self#::Core.Const(Bijectors.link)
d::Dirichlet{Float64, Vector{Float64}, Float64}
x::Vector{Float64}
Body::Vector{Float64}
1 ─ %1 = (#self#)(d, x, true)::Vector{Float64}
└── return %1
as wanted. |
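The `Core.apply_type(Main.SimplexBijector, proj)::Type` line in the `invlink(d, x, true)` output above suggests the general pattern: a runtime `Bool` is used as a type parameter, so the compiler cannot know which concrete type gets constructed. The following is a toy sketch of that pattern in plain Julia. `ToyBijector`, `make_runtime` and `make_val` are hypothetical names for illustration only, and this is not the fix that went into Bijectors.

```julia
# Toy stand-in for a bijector type parameterised by a Bool
# (hypothetical, not Bijectors.jl code).
struct ToyBijector{proj} end

# The flag is a runtime value, so the type parameter is unknown to inference.
make_runtime(proj::Bool) = ToyBijector{proj}()

# The flag is carried in the type domain via Val, so inference knows the parameter.
make_val(::Val{proj}) where {proj} = ToyBijector{proj}()

# Ask the compiler what it infers from the argument types alone.
rt_runtime = first(Base.return_types(make_runtime, (Bool,)))
rt_val = first(Base.return_types(make_val, (Val{true},)))

println("runtime flag: ", rt_runtime, " (concrete: ", isconcretetype(rt_runtime), ")")
println("Val flag:     ", rt_val, " (concrete: ", isconcretetype(rt_val), ")")
```

Lifting the flag into `Val` only helps when the caller already knows it at compile time. Calling `Val(proj)` with a runtime `proj` just moves the instability to the call site.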
Alternatively, one could just dispatch on |
Fixed in Bijectors 0.9 and Turing 0.15.13 which should be available soon (JuliaRegistries/General#33714). |
Drawing a variable from a Dirichlet distribution introduces a type instability, slowing down sampling dramatically (around 25x on my laptop):