Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add closed-form posterior for Gamma-Exponential observation model #133

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 86 additions & 1 deletion aemcmc/conjugates.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@
from aesara.graph.rewriting.basic import in2out
from aesara.graph.rewriting.db import LocalGroupDB
from aesara.graph.rewriting.unify import eval_if_etuple
from aesara.tensor.random.basic import BinomialRV, NegBinomialRV, PoissonRV
from aesara.tensor.random.basic import (
BinomialRV,
ExponentialRV,
NegBinomialRV,
PoissonRV,
)
from etuples import etuple, etuplize
from kanren import eq, lall, run
from unification import var
Expand Down Expand Up @@ -238,13 +243,93 @@ def local_beta_negative_binomial_posterior(fgraph, node, srng):
return [(beta_rv, beta_posterior, None)]


def gamma_exponential_conjugateo(
srng: "RandomStream", observed_rv_expr, posterior_expr
):
r"""
Relation for the conjugate posterior of a gamma prior with an exponential observation model.

.. math::

\frac{
Y \sim \operatorname{Exp}\left(\lambda\right), \quad
\lambda \sim \operatorname{Gamma}\left(\alpha, \beta\right)
}{
\left(\lambda|Y=y\right) \sim \operatorname{Gamma}\left(\alpha+1, \beta+y\right)
}

Parameters
----------
srng
The `RandomStream` used to generate the posterior variates.
observed_rv_expr
An expression that represents the observed variable.
posterior_exp
An expression that represents the posterior distribution of the latent
variable.

"""
# Gamma-exponential observation model
alpha_lv, beta_lv = var(), var()
lam_rng_lv = var()
lam_size_lv = var()
lam_type_idx_lv = var()
lam_et = etuple(
etuplize(at.random.gamma),
lam_rng_lv,
lam_size_lv,
lam_type_idx_lv,
alpha_lv,
beta_lv,
)
Y_et = etuple(etuplize(at.random.exponential), var(), var(), var(), lam_et)

# Posterior distribution for lambda
new_alpha_et = etuple(etuplize(at.add), alpha_lv, 1)
new_beta_et = etuple(etuplize(at.add), beta_lv, observed_rv_expr)

lam_posterior_et = etuple(
partial(srng.gen, at.random.gamma),
new_alpha_et,
new_beta_et,
size=lam_size_lv,
dtype=lam_type_idx_lv,
)

return lall(
eq(observed_rv_expr, Y_et),
eq(posterior_expr, lam_posterior_et),
)


@sampler_finder([ExponentialRV])
def local_gamma_exponential_posterior(fgraph, node, srng):
rv_var = node.outputs[1]

q = var()

rv_et = etuplize(rv_var)

res = run(None, q, partial(beta_negative_binomial_conjugateo, srng)(rv_et, q))
res = next(res, None)

if res is None:
return None # pragma: no cover

lam_rv = rv_et[-1].evaled_obj
lam_posterior = eval_if_etuple(res)

return [(lam_rv, lam_posterior, None)]


conjugates_db = LocalGroupDB(apply_all_rewrites=True)
conjugates_db.name = "conjugates_db"
conjugates_db.register("beta_binomial", local_beta_binomial_posterior, "basic")
conjugates_db.register("gamma_poisson", local_gamma_poisson_posterior, "basic")
conjugates_db.register(
"negative_binomial", local_beta_negative_binomial_posterior, "basic"
)
conjugates_db.register("gamma_exponential", local_gamma_exponential_posterior, "basic")


sampler_finder_db.register(
Expand Down
37 changes: 37 additions & 0 deletions tests/test_conjugates.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from aemcmc.conjugates import (
beta_binomial_conjugateo,
beta_negative_binomial_conjugateo,
gamma_exponential_conjugateo,
gamma_poisson_conjugateo,
)

Expand Down Expand Up @@ -142,3 +143,39 @@ def test_beta_negative_binomial_conjugate_expand():
expanded = expanded_expr

assert isinstance(expanded.owner.op, type(at.random.beta))


def test_gamma_exponential_conjugate_contract():
"""Produce the closed-form posterior for the exponential observation model with a gamma prior."""
srng = RandomStream(0)

alpha_tt = at.scalar("alpha")
beta_tt = at.scalar("beta")
lam_rv = srng.gamma(alpha_tt, beta_tt)
Y_rv = srng.exponential(lam_rv)

q_lv = var()
(posterior_expr,) = run(1, q_lv, gamma_exponential_conjugateo(srng, Y_rv, q_lv))
posterior = posterior_expr.evaled_obj

assert isinstance(posterior.owner.op, type(at.random.gamma))


@pytest.mark.xfail(
reason="Op.__call__ does not dispatch to Op.make_node for some RandomVariable and etuple evaluation returns an error"
)
def test_gamma_exponential_conjugate_expand():
"""Expand a contracted gamma-exponential observation model."""

srng = RandomStream(0)

alpha_tt = at.scalar("alpha")
beta_tt = at.scalar("beta")
y_vv = at.iscalar("y")
Y_rv = srng.gamma(alpha_tt + y_vv, beta_tt + 1)

e_lv = var()
(expanded_expr,) = run(1, e_lv, gamma_exponential_conjugateo(srng, e_lv, Y_rv))
expanded = expanded_expr.evaled_obj

assert isinstance(expanded.owner.op, type(at.random.gamma))
Loading