Marginalize continuous variable via QMC integration #353

Open · wants to merge 1 commit into base: main
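In brief: this PR teaches MarginalModel to marginalize continuous (normal and multivariate normal) variables by quasi-Monte Carlo (QMC) integration. The new logp draws 2**qmc_order Sobol points, maps them through the marginalized variable's inverse CDF (or a Cholesky transform for MvNormal), evaluates the dependent variables' logps at each draw, and averages in log space via logsumexp(logp) - log(n_draws). A minimal NumPy/SciPy sketch of that estimator (illustrative only; the seed, draw count, and tolerance here are assumptions, not part of the diff):

    import numpy as np
    from scipy import stats
    from scipy.special import logsumexp

    # Marginalize X out of: X ~ Normal(0, sd), Y ~ Normal(2*X + 1, 1), observed y
    sd = 2.0
    y = np.array([1.0, 2.0, 3.0])

    n_draws = 2**13  # the PR hard-codes qmc_order=13
    u = stats.qmc.Sobol(d=1, seed=123).random(n_draws)  # scrambled uniform QMC draws
    x = stats.norm(0, sd).ppf(u)                        # push through the inverse CDF
    logp_cond = stats.norm(2 * x + 1, 1).logpdf(y)      # log p(y | x_i), shape (n_draws, 3)
    logp_marg = logsumexp(logp_cond, axis=0) - np.log(n_draws)

    # Analytically Y ~ Normal(1, sqrt(4*sd**2 + 1)) = Normal(1, sqrt(17))
    np.testing.assert_allclose(logp_marg, stats.norm(1, np.sqrt(17)).logpdf(y), rtol=1e-4)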
159 changes: 143 additions & 16 deletions pymc_experimental/model/marginal_model.py
@@ -4,16 +4,18 @@
import numpy as np
import pymc
import pytensor.tensor as pt
import scipy
from arviz import InferenceData, dict_to_dataset
from pymc import SymbolicRandomVariable
from pymc.backends.arviz import coords_and_dims_for_inferencedata, dataset_to_point_list
from pymc.distributions import MvNormal, SymbolicRandomVariable
from pymc.distributions.continuous import Continuous
from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform
from pymc.distributions.transforms import Chain
from pymc.logprob.abstract import _logprob
from pymc.logprob.basic import conditional_logp, logp
from pymc.logprob.basic import conditional_logp, icdf, logp
from pymc.logprob.transforms import IntervalTransform
from pymc.model import Model
from pymc.pytensorf import compile_pymc, constant_fold
from pymc.pytensorf import collect_default_updates, compile_pymc, constant_fold
from pymc.util import RandomState, _get_seeds_per_chain, treedict
from pytensor import Mode, scan
from pytensor.compile import SharedVariable
@@ -159,17 +161,17 @@ def _marginalize(self, user_warnings=False):
f"Cannot marginalize {rv_to_marginalize} due to dependent Potential {pot}"
)

old_rvs, new_rvs = replace_finite_discrete_marginal_subgraph(
fg, rv_to_marginalize, self.basic_RVs + rvs_left_to_marginalize
if isinstance(rv_to_marginalize.owner.op, Continuous):
subgraph_builder_fn = replace_continuous_marginal_subgraph
else:
subgraph_builder_fn = replace_finite_discrete_marginal_subgraph
old_rvs, new_rvs = subgraph_builder_fn(
fg,
rv_to_marginalize,
self.basic_RVs + rvs_left_to_marginalize,
user_warnings=user_warnings,
)

if user_warnings and len(new_rvs) > 2:
warnings.warn(
"There are multiple dependent variables in a FiniteDiscreteMarginalRV. "
f"Their joint logp terms will be assigned to the first RV: {old_rvs[1]}",
UserWarning,
)

rvs_left_to_marginalize.remove(rv_to_marginalize)

for old_rv, new_rv in zip(old_rvs, new_rvs):
@@ -267,7 +269,11 @@ def marginalize(
)

rv_op = rv_to_marginalize.owner.op
if isinstance(rv_op, DiscreteMarkovChain):

if isinstance(rv_op, (Bernoulli, Categorical, DiscreteUniform)):
pass

elif isinstance(rv_op, DiscreteMarkovChain):
if rv_op.n_lags > 1:
raise NotImplementedError(
"Marginalization for DiscreteMarkovChain with n_lags > 1 is not supported"
@@ -276,7 +282,11 @@
raise NotImplementedError(
"Marginalization for DiscreteMarkovChain with non-matrix transition probability is not supported"
)
elif not isinstance(rv_op, (Bernoulli, Categorical, DiscreteUniform)):

elif isinstance(rv_op, Continuous):
pass

else:
raise NotImplementedError(
f"Marginalization of RV with distribution {rv_to_marginalize.owner.op} is not supported"
)
@@ -449,7 +459,7 @@ def transform_input(inputs):
rv_loglike_fn = None
joint_logps_norm = log_softmax(joint_logps, axis=-1)
if return_samples:
sample_rv_outs = pymc.Categorical.dist(logit_p=joint_logps)
sample_rv_outs = Categorical.dist(logit_p=joint_logps)
if isinstance(marginalized_rv.owner.op, DiscreteUniform):
sample_rv_outs += rv_domain[0]

@@ -549,6 +559,16 @@ class DiscreteMarginalMarkovChainRV(MarginalRV):
"""Base class for Discrete Marginal Markov Chain RVs"""


class QMCMarginalNormalRV(MarginalRV):
"""Basec class for QMC Marginalized RVs"""

__props__ = ("qmc_order",)

def __init__(self, *args, qmc_order: int, **kwargs):
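# the logp uses 2**qmc_order Sobol points for the integration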
self.qmc_order = qmc_order
super().__init__(*args, **kwargs)


def static_shape_ancestors(vars):
"""Identify ancestors Shape Ops of static shapes (therefore constant in a valid graph)."""
return [
@@ -646,7 +666,9 @@ def collect_shared_vars(outputs, blockers):
]


def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs):
def replace_finite_discrete_marginal_subgraph(
fgraph, rv_to_marginalize, all_rvs, user_warnings: bool = False
):
# TODO: This should eventually be integrated in a more general routine that can
# identify other types of supported marginalization, of which finite discrete
# RVs is just one
@@ -655,6 +677,13 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
if not dependent_rvs:
raise ValueError(f"No RVs depend on marginalized RV {rv_to_marginalize}")

if user_warnings and len(dependent_rvs) > 1:
warnings.warn(
"There are multiple dependent variables in a FiniteDiscreteMarginalRV. "
f"Their joint logp terms will be assigned to the first RV: {dependent_rvs[0]}",
UserWarning,
)

ndim_supp = {rv.owner.op.ndim_supp for rv in dependent_rvs}
if len(ndim_supp) != 1:
raise NotImplementedError(
@@ -707,6 +736,39 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs
return rvs_to_marginalize, marginalized_rvs


def replace_continuous_marginal_subgraph(
fgraph, rv_to_marginalize, all_rvs, user_warnings: bool = False
):
dependent_rvs = find_conditional_dependent_rvs(rv_to_marginalize, all_rvs)
if not dependent_rvs:
raise ValueError(f"No RVs depend on marginalized RV {rv_to_marginalize}")

marginalized_rv_input_rvs = find_conditional_input_rvs([rv_to_marginalize], all_rvs)
dependent_rvs_input_rvs = [
rv
for rv in find_conditional_input_rvs(dependent_rvs, all_rvs)
if rv is not rv_to_marginalize
]

input_rvs = [*marginalized_rv_input_rvs, *dependent_rvs_input_rvs]
rvs_to_marginalize = [rv_to_marginalize, *dependent_rvs]

outputs = rvs_to_marginalize
# We are strict about shared variables in SymbolicRandomVariables
inputs = input_rvs + collect_shared_vars(rvs_to_marginalize, blockers=input_rvs)

# TODO: Assert no non-marginalized variables depend on the rng output of the marginalized variables!!!
marginalized_rvs = QMCMarginalNormalRV(
inputs=inputs,
outputs=[*outputs, *collect_default_updates(inputs=inputs, outputs=outputs).values()],
ndim_supp=max([rv.owner.op.ndim_supp for rv in dependent_rvs]),
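# hard-coded for now: 2**13 = 8192 Sobol draws per logp evaluation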
qmc_order=13,
)(*inputs)[: len(outputs)]

fgraph.replace_all(tuple(zip(rvs_to_marginalize, marginalized_rvs)))
return rvs_to_marginalize, marginalized_rvs


def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]:
op = rv.owner.op
dist_params = rv.owner.op.dist_params(rv.owner)
@@ -870,3 +932,68 @@ def step_alpha(logp_emission, log_alpha, log_P):
# return is the joint probability of everything together, but PyMC still expects one logp for each one.
dummy_logps = (pt.constant(0),) * (len(values) - 1)
return joint_logp, *dummy_logps


@_logprob.register(QMCMarginalNormalRV)
def qmc_marginal_rv_logp(op, values, *inputs, **kwargs):
# Clone the inner RV graph of the Marginalized RV
marginalized_rvs_node = op.make_node(*inputs)
# The MarginalizedRV contains the following outputs:
# 1. The variable we marginalized
# 2. The dependent variables
# 3. The updates for the marginalized and dependent variables
marginalized_rv, *inner_rvs_and_updates = clone_replace(
op.inner_outputs,
replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)},
)
inner_rvs = inner_rvs_and_updates[: (len(inner_rvs_and_updates) - 1) // 2]

marginalized_vv = marginalized_rv.clone()
marginalized_rv_node = marginalized_rv.owner
marginalized_rv_op = marginalized_rv_node.op

# Get QMC draws from the marginalized RV
# TODO: Make this an Op
rng = marginalized_rv_op.rng_param(marginalized_rv_node)
shape = constant_fold(tuple(marginalized_rv.shape))
size = np.prod(shape).astype(int)
n_draws = 2**op.qmc_order

# TODO: Wrap Sobol in an Op so we can control the RNG and reseed it whenever needed
qmc_engine = scipy.stats.qmc.Sobol(d=size, seed=rng.get_value(borrow=False))
uniform_draws = qmc_engine.random(n_draws).reshape((n_draws, *shape))

if isinstance(marginalized_rv_op, MvNormal):
# Adapted from https://github.com/scipy/scipy/blob/87c46641a8b3b5b47b81de44c07b840468f7ebe7/scipy/stats/_qmc.py#L2211-L2298
mean, cov = marginalized_rv_op.dist_params(marginalized_rv_node)
corr_matrix = pt.linalg.cholesky(cov).mT
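# shrink the uniforms slightly away from 0 and 1 so norm.ppf stays finite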
base_draws = pt.as_tensor(scipy.stats.norm.ppf(0.5 + (1 - 1e-10) * (uniform_draws - 0.5)))
qmc_draws = base_draws @ corr_matrix + mean
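# base_draws are standard-normal QMC points z; z @ L.T + mean follows MvNormal(mean, cov) when cov = L @ L.T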
else:
qmc_draws = vectorize_graph(
icdf(marginalized_rv, marginalized_vv),
replace={marginalized_vv: uniform_draws},
)

qmc_draws.name = f"QMC_{marginalized_rv_op.name}_draws"

# Obtain the logp of the dependent variables
# We need to include the marginalized RV for correctness; we remove it later.
inner_rv_values = dict(zip(inner_rvs, values))
rv_values = inner_rv_values | {marginalized_rv: marginalized_vv}
logps_dict = conditional_logp(rv_values=rv_values, **kwargs)
# Pop the logp term corresponding to the marginalized RV
# (it is already accounted for by the QMC draws, which follow the marginalized distribution)
logps_dict.pop(marginalized_vv)

# Vectorize the dependent logps across the QMC draws
core_marginalized_logps = list(logps_dict.values())
batched_marginalized_logps = vectorize_graph(
core_marginalized_logps, replace={marginalized_vv: qmc_draws}
)

# Take the mean in log scale
return tuple(
pt.logsumexp(batched_marginalized_logp, axis=0) - pt.log(n_draws)
for batched_marginalized_logp in batched_marginalized_logps
)
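The vectorize_graph step above is what lets a logp graph, written for a single value of the marginalized variable, be re-evaluated over the whole batch of QMC draws at once. A toy illustration (assumes vectorize_graph is importable from pytensor.graph.replace, as the module above presumably does; names are stand-ins):

    import numpy as np
    import pytensor.tensor as pt
    from pytensor.graph.replace import vectorize_graph

    x = pt.scalar("x")              # stands in for the marginalized value variable
    logp = -0.5 * (x - 1.0) ** 2    # stands in for one conditional logp term
    draws = pt.vector("draws")      # a batch of QMC draws
    batched_logp = vectorize_graph(logp, replace={x: draws})
    print(batched_logp.eval({draws: np.array([0.0, 1.0, 2.0])}))  # [-0.5  0.  -0.5]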
63 changes: 63 additions & 0 deletions pymc_experimental/tests/model/test_marginal_model.py
@@ -6,12 +6,14 @@
import pymc as pm
import pytensor.tensor as pt
import pytest
import scipy
from arviz import InferenceData, dict_to_dataset
from pymc.distributions import transforms
from pymc.logprob.abstract import _logprob
from pymc.model.fgraph import fgraph_from_model
from pymc.pytensorf import inputvars
from pymc.util import UNSET
from pytensor.graph import FunctionGraph
from scipy.special import log_softmax, logsumexp
from scipy.stats import halfnorm, norm

@@ -21,6 +23,7 @@
MarginalModel,
is_conditional_dependent,
marginalize,
replace_continuous_marginal_subgraph,
)
from pymc_experimental.tests.utils import equal_computations_up_to_root

@@ -803,3 +806,63 @@ def create_model(model_class):
marginal_m.compile_logp()(ip),
reference_m.compile_logp()(ip),
)


@pytest.mark.parametrize("univariate", (True, False), ids=["univariate", "multivariate"])
@pytest.mark.parametrize(
"multiple_dependent", (False, True), ids=["single-dependent", "multiple-dependent"]
)
def test_marginalize_normal_qmc(univariate, multiple_dependent):
with MarginalModel() as m:
SD = pm.HalfNormal("SD", default_transform=None)
if univariate:
X = pm.Normal("X", sigma=SD, shape=(3,))
else:
X = pm.MvNormal("X", mu=[0, 0, 0], cov=np.eye(3) * SD**2)

if multiple_dependent:
Y = [
pm.Normal("Y[0]", mu=(2 * X[0] + 1), sigma=1, observed=1),
pm.Normal("Y[1:]", mu=(2 * X[1:] + 1), sigma=1, observed=[2, 3]),
]
else:
Y = [pm.Normal("Y", mu=(2 * X + 1), sigma=1, observed=[1, 2, 3])]

m.marginalize([X]) # ideally method="qmc"

logp_eval = np.hstack(m.compile_logp(vars=Y, sum=False)({"SD": 2.0}))

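# analytic marginal: Y = 2*X + 1 + eps with X ~ N(0, SD**2) gives Y ~ N(1, 4*SD**2 + 1) = N(1, 17) for SD = 2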
np.testing.assert_allclose(
logp_eval,
scipy.stats.norm.logpdf([1, 2, 3], 1, np.sqrt(17)),
rtol=1e-5,
)


def test_marginalize_non_trivial_mvnormal_qmc():
with MarginalModel() as m:
SD = pm.HalfNormal("SD", default_transform=None)
X = pm.MvNormal("X", cov=[[1.0, 0.5], [0.5, 1.0]] * SD**2)
Y = pm.MvNormal("Y", mu=2 * X + 1, cov=np.eye(2), observed=[1, 2])

m.marginalize([X])

[logp_eval] = m.compile_logp(vars=Y, sum=False)({"SD": 1})

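# analytic marginal: cov(Y) = 4 * [[1, 0.5], [0.5, 1]] + I = [[5, 2], [2, 5]], mean = [1, 1]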
np.testing.assert_allclose(
logp_eval,
scipy.stats.multivariate_normal.logpdf([1, 2], [1, 1], [[5, 2], [2, 5]]),
rtol=1e-5,
)


def test_marginalize_sample():
with pm.Model() as m:
SD = pm.HalfNormal("SD")
X = pm.Normal.dist(sigma=SD, name="X")
Y = pm.Normal("Y", mu=(2 * X + 1), sigma=1, observed=[1, 2, 3])

fg = FunctionGraph(outputs=[SD, Y, X], clone=False)
old_rvs, new_rvs = replace_continuous_marginal_subgraph(fg, X, [Y, SD, X])
res1, res2 = pm.draw(new_rvs, draws=2)
assert not np.allclose(res1, res2)