Skip to content

Commit

Permalink
Implement Mixture logcdf
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Mar 11, 2022
1 parent 2d65929 commit 7a80d49
Showing 1 changed file with 34 additions and 2 deletions.
36 changes: 34 additions & 2 deletions pymc/distributions/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import numpy as np

from aeppl.abstract import MeasurableVariable, _get_measurable_outputs
from aeppl.logprob import _logprob
from aeppl.logprob import _logcdf, _logprob
from aesara.compile.builders import OpFromGraph
from aesara.tensor import TensorVariable
from aesara.tensor.random.op import RandomVariable
Expand All @@ -33,7 +33,7 @@
_get_moment,
get_moment,
)
from pymc.distributions.logprob import logp
from pymc.distributions.logprob import logcdf, logp
from pymc.distributions.shape_utils import to_tuple
from pymc.util import check_dist_not_registered

Expand Down Expand Up @@ -404,6 +404,38 @@ def marginal_mixture_logprob(op, values, rng, weights, *components, **kwargs):
return mix_logp


@_logcdf.register(MarginalMixtureRV)
def marginal_mixture_logcdf(op, value, rng, weights, *components, **kwargs):

# single component
if len(components) == 1:
# Need to broadcast value across mixture axis
mix_axis = -components[0].owner.op.ndim_supp - 1
components_logcdf = logcdf(components[0], at.expand_dims(value, mix_axis))
else:
components_logcdf = at.stack(
[logcdf(component, value) for component in components],
axis=-1,
)

mix_logcdf = at.logsumexp(at.log(weights) + components_logcdf, axis=-1)

# Squeeze stack dimension
# There is a Aeasara bug in squeeze with negative axis
# mix_logp = at.squeeze(mix_logp, axis=-1)
mix_logcdf = at.squeeze(mix_logcdf, axis=mix_logcdf.ndim - 1)

mix_logcdf = check_parameters(
mix_logcdf,
0 <= weights,
weights <= 1,
at.isclose(at.sum(weights, axis=-1), 1),
msg="0 <= weights <= 1, sum(weights) == 1",
)

return mix_logcdf


@_get_moment.register(MarginalMixtureRV)
def get_moment_marginal_mixture(op, rv, rng, weights, *components):
ndim_supp = components[0].owner.op.ndim_supp
Expand Down

0 comments on commit 7a80d49

Please sign in to comment.