Skip to content

Commit

Permalink
Add logp parametrization for Categorical
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Aug 26, 2021
1 parent 28b5539 commit 2f7e03c
Showing 1 changed file with 14 additions and 1 deletion.
15 changes: 14 additions & 1 deletion src/parameterized/categorical.jl
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,22 @@ ncategories(d::Categorical) = length(d.p)

logdensity(d::Categorical{(:p)}, y) = log(d.p[y])

# Very inefficient because of the heavy implementation of Dists.DiscreteNonParametric
# The implementation of Dists.DiscreteNonParametric has heavy argument checks
# But I think since the values of Categorical are 1:n the sortperm has no effect
# So it might be OK
distproxy(d::Categorical{(:p)}) = Dists.Categorical(d.p)

Base.rand(rng::AbstractRNG, T::Type, d::Categorical{(:p)}) = rand(rng, distproxy(d))

asparams(::Type{<:Categorical}, ::Val{:p}) = as𝕀

###############################################################################
@kwstruct Categorical(logp)

logdensity(d::Categorical{(:logp)}, y) = d.logp[y]

distproxy(d::Categorical{(:logp)}) = Dists.Categorical(exp.(d.logp)) # inefficient

Base.rand(rng::AbstractRNG, T::Type, d::Categorical{(:logp)}) = rand(rng, distproxy(d))

asparams(::Type{<:Categorical}, ::Val{:logp}) = asℝ

0 comments on commit 2f7e03c

Please sign in to comment.