From 2f7e03c39e85414cff13e20bb0bb9b8e45fc92a7 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Thu, 26 Aug 2021 09:43:50 +0200 Subject: [PATCH] Add logp parametrization for Categorical --- src/parameterized/categorical.jl | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/src/parameterized/categorical.jl b/src/parameterized/categorical.jl index 4e2abc4d..f4ccde37 100644 --- a/src/parameterized/categorical.jl +++ b/src/parameterized/categorical.jl @@ -19,9 +19,22 @@ ncategories(d::Categorical) = length(d.p) logdensity(d::Categorical{(:p)}, y) = log(d.p[y]) -# Very inefficient because of the heavy implementation of Dists.DiscreteNonParametric +# The implementation of Dists.DiscreteNonParametric has heavy argument checks +# But I think since the values of Categorical are 1:n the sortperm has no effect +# So it might be OK distproxy(d::Categorical{(:p)}) = Dists.Categorical(d.p) Base.rand(rng::AbstractRNG, T::Type, d::Categorical{(:p)}) = rand(rng, distproxy(d)) asparams(::Type{<:Categorical}, ::Val{:p}) = as𝕀 + +############################################################################### +@kwstruct Categorical(logp) + +logdensity(d::Categorical{(:logp)}, y) = d.logp[y] + +distproxy(d::Categorical{(:logp)}) = Dists.Categorical(exp.(d.logp)) # inefficient + +Base.rand(rng::AbstractRNG, T::Type, d::Categorical{(:logp)}) = rand(rng, distproxy(d)) + +asparams(::Type{<:Categorical}, ::Val{:logp}) = asℝ