diff --git a/pymc_experimental/model/marginal_model.py b/pymc_experimental/model/marginal_model.py
index ead9a362..06b54829 100644
--- a/pymc_experimental/model/marginal_model.py
+++ b/pymc_experimental/model/marginal_model.py
@@ -4,16 +4,18 @@
 import numpy as np
 import pymc
 import pytensor.tensor as pt
+import scipy
 from arviz import InferenceData, dict_to_dataset
-from pymc import SymbolicRandomVariable
 from pymc.backends.arviz import coords_and_dims_for_inferencedata, dataset_to_point_list
+from pymc.distributions import MvNormal, SymbolicRandomVariable
+from pymc.distributions.continuous import Continuous
 from pymc.distributions.discrete import Bernoulli, Categorical, DiscreteUniform
 from pymc.distributions.transforms import Chain
 from pymc.logprob.abstract import _logprob
-from pymc.logprob.basic import conditional_logp, logp
+from pymc.logprob.basic import conditional_logp, icdf, logp
 from pymc.logprob.transforms import IntervalTransform
 from pymc.model import Model
-from pymc.pytensorf import compile_pymc, constant_fold
+from pymc.pytensorf import collect_default_updates, compile_pymc, constant_fold
 from pymc.util import RandomState, _get_seeds_per_chain, treedict
 from pytensor import Mode, scan
 from pytensor.compile import SharedVariable
@@ -159,17 +161,17 @@ def _marginalize(self, user_warnings=False):
                     f"Cannot marginalize {rv_to_marginalize} due to dependent Potential {pot}"
                 )

-        old_rvs, new_rvs = replace_finite_discrete_marginal_subgraph(
-            fg, rv_to_marginalize, self.basic_RVs + rvs_left_to_marginalize
+        if isinstance(rv_to_marginalize.owner.op, Continuous):
+            subgraph_builder_fn = replace_continuous_marginal_subgraph
+        else:
+            subgraph_builder_fn = replace_finite_discrete_marginal_subgraph
+        old_rvs, new_rvs = subgraph_builder_fn(
+            fg,
+            rv_to_marginalize,
+            self.basic_RVs + rvs_left_to_marginalize,
+            user_warnings=user_warnings,
         )
-        if user_warnings and len(new_rvs) > 2:
-            warnings.warn(
-                "There are multiple dependent variables in a FiniteDiscreteMarginalRV. "
" - f"Their joint logp terms will be assigned to the first RV: {old_rvs[1]}", - UserWarning, - ) - rvs_left_to_marginalize.remove(rv_to_marginalize) for old_rv, new_rv in zip(old_rvs, new_rvs): @@ -267,7 +269,11 @@ def marginalize( ) rv_op = rv_to_marginalize.owner.op - if isinstance(rv_op, DiscreteMarkovChain): + + if isinstance(rv_op, (Bernoulli, Categorical, DiscreteUniform)): + pass + + elif isinstance(rv_op, DiscreteMarkovChain): if rv_op.n_lags > 1: raise NotImplementedError( "Marginalization for DiscreteMarkovChain with n_lags > 1 is not supported" @@ -276,7 +282,11 @@ def marginalize( raise NotImplementedError( "Marginalization for DiscreteMarkovChain with non-matrix transition probability is not supported" ) - elif not isinstance(rv_op, (Bernoulli, Categorical, DiscreteUniform)): + + elif isinstance(rv_op, Continuous): + pass + + else: raise NotImplementedError( f"Marginalization of RV with distribution {rv_to_marginalize.owner.op} is not supported" ) @@ -449,7 +459,7 @@ def transform_input(inputs): rv_loglike_fn = None joint_logps_norm = log_softmax(joint_logps, axis=-1) if return_samples: - sample_rv_outs = pymc.Categorical.dist(logit_p=joint_logps) + sample_rv_outs = Categorical.dist(logit_p=joint_logps) if isinstance(marginalized_rv.owner.op, DiscreteUniform): sample_rv_outs += rv_domain[0] @@ -549,6 +559,16 @@ class DiscreteMarginalMarkovChainRV(MarginalRV): """Base class for Discrete Marginal Markov Chain RVs""" +class QMCMarginalNormalRV(MarginalRV): + """Basec class for QMC Marginalized RVs""" + + __props__ = ("qmc_order",) + + def __init__(self, *args, qmc_order: int, **kwargs): + self.qmc_order = qmc_order + super().__init__(*args, **kwargs) + + def static_shape_ancestors(vars): """Identify ancestors Shape Ops of static shapes (therefore constant in a valid graph).""" return [ @@ -646,7 +666,9 @@ def collect_shared_vars(outputs, blockers): ] -def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs): +def replace_finite_discrete_marginal_subgraph( + fgraph, rv_to_marginalize, all_rvs, user_warnings: bool = False +): # TODO: This should eventually be integrated in a more general routine that can # identify other types of supported marginalization, of which finite discrete # RVs is just one @@ -655,6 +677,13 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs if not dependent_rvs: raise ValueError(f"No RVs depend on marginalized RV {rv_to_marginalize}") + if user_warnings and len(dependent_rvs) > 1: + warnings.warn( + "There are multiple dependent variables in a FiniteDiscreteMarginalRV. 
" + f"Their joint logp terms will be assigned to the first RV: {dependent_rvs[0]}", + UserWarning, + ) + ndim_supp = {rv.owner.op.ndim_supp for rv in dependent_rvs} if len(ndim_supp) != 1: raise NotImplementedError( @@ -707,6 +736,39 @@ def replace_finite_discrete_marginal_subgraph(fgraph, rv_to_marginalize, all_rvs return rvs_to_marginalize, marginalized_rvs +def replace_continuous_marginal_subgraph( + fgraph, rv_to_marginalize, all_rvs, user_warnings: bool = False +): + dependent_rvs = find_conditional_dependent_rvs(rv_to_marginalize, all_rvs) + if not dependent_rvs: + raise ValueError(f"No RVs depend on marginalized RV {rv_to_marginalize}") + + marginalized_rv_input_rvs = find_conditional_input_rvs([rv_to_marginalize], all_rvs) + dependent_rvs_input_rvs = [ + rv + for rv in find_conditional_input_rvs(dependent_rvs, all_rvs) + if rv is not rv_to_marginalize + ] + + input_rvs = [*marginalized_rv_input_rvs, *dependent_rvs_input_rvs] + rvs_to_marginalize = [rv_to_marginalize, *dependent_rvs] + + outputs = rvs_to_marginalize + # We are strict about shared variables in SymbolicRandomVariables + inputs = input_rvs + collect_shared_vars(rvs_to_marginalize, blockers=input_rvs) + + # TODO: Assert no non-marginalized variables depend on the rng output of the marginalized variables!!! + marginalized_rvs = QMCMarginalNormalRV( + inputs=inputs, + outputs=[*outputs, *collect_default_updates(inputs=inputs, outputs=outputs).values()], + ndim_supp=max([rv.owner.op.ndim_supp for rv in dependent_rvs]), + qmc_order=13, + )(*inputs)[: len(outputs)] + + fgraph.replace_all(tuple(zip(rvs_to_marginalize, marginalized_rvs))) + return rvs_to_marginalize, marginalized_rvs + + def get_domain_of_finite_discrete_rv(rv: TensorVariable) -> tuple[int, ...]: op = rv.owner.op dist_params = rv.owner.op.dist_params(rv.owner) @@ -870,3 +932,68 @@ def step_alpha(logp_emission, log_alpha, log_P): # return is the joint probability of everything together, but PyMC still expects one logp for each one. dummy_logps = (pt.constant(0),) * (len(values) - 1) return joint_logp, *dummy_logps + + +@_logprob.register(QMCMarginalNormalRV) +def qmc_marginal_rv_logp(op, values, *inputs, **kwargs): + # Clone the inner RV graph of the Marginalized RV + marginalized_rvs_node = op.make_node(*inputs) + # The MarginalizedRV contains the following outputs: + # 1. The variable we marginalized + # 2. The dependent variables + # 3. 
+    marginalized_rv, *inner_rvs_and_updates = clone_replace(
+        op.inner_outputs,
+        replace={u: v for u, v in zip(op.inner_inputs, marginalized_rvs_node.inputs)},
+    )
+    inner_rvs = inner_rvs_and_updates[: (len(inner_rvs_and_updates) - 1) // 2]
+
+    marginalized_vv = marginalized_rv.clone()
+    marginalized_rv_node = marginalized_rv.owner
+    marginalized_rv_op = marginalized_rv_node.op
+
+    # Get QMC draws from the marginalized RV
+    # TODO: Make this an Op
+    rng = marginalized_rv_op.rng_param(marginalized_rv_node)
+    shape = constant_fold(tuple(marginalized_rv.shape))
+    size = np.prod(shape).astype(int)
+    n_draws = 2**op.qmc_order
+
+    # TODO: Wrap Sobol in an Op so we can control the RNG and change it whenever needed
+    qmc_engine = scipy.stats.qmc.Sobol(d=size, seed=rng.get_value(borrow=False))
+    uniform_draws = qmc_engine.random(n_draws).reshape((n_draws, *shape))
+
+    if isinstance(marginalized_rv_op, MvNormal):
+        # Adapted from https://github.com/scipy/scipy/blob/87c46641a8b3b5b47b81de44c07b840468f7ebe7/scipy/stats/_qmc.py#L2211-L2298
+        mean, cov = marginalized_rv_op.dist_params(marginalized_rv_node)
+        corr_matrix = pt.linalg.cholesky(cov).mT
+        base_draws = pt.as_tensor(scipy.stats.norm.ppf(0.5 + (1 - 1e-10) * (uniform_draws - 0.5)))
+        qmc_draws = base_draws @ corr_matrix + mean
+    else:
+        qmc_draws = vectorize_graph(
+            icdf(marginalized_rv, marginalized_vv),
+            replace={marginalized_vv: uniform_draws},
+        )
+
+    qmc_draws.name = f"QMC_{marginalized_rv_op.name}_draws"
+
+    # Obtain the logp of the dependent variables.
+    # We need to include the marginalized RV for correctness; we remove it later.
+    inner_rv_values = dict(zip(inner_rvs, values))
+    rv_values = inner_rv_values | {marginalized_rv: marginalized_vv}
+    logps_dict = conditional_logp(rv_values=rv_values, **kwargs)
+    # Pop the logp term corresponding to the marginalized RV
+    # (it is already accounted for, since the QMC draws come from its distribution)
+    logps_dict.pop(marginalized_vv)
+
+    # Vectorize the dependent logp terms across the QMC draws
+    core_marginalized_logps = list(logps_dict.values())
+    batched_marginalized_logps = vectorize_graph(
+        core_marginalized_logps, replace={marginalized_vv: qmc_draws}
+    )
+
+    # Take the mean over draws in log scale (log-mean-exp)
+    return tuple(
+        pt.logsumexp(batched_marginalized_logp, axis=0) - pt.log(n_draws)
+        for batched_marginalized_logp in batched_marginalized_logps
+    )
diff --git a/pymc_experimental/tests/model/test_marginal_model.py b/pymc_experimental/tests/model/test_marginal_model.py
index 31e38615..90c9a927 100644
--- a/pymc_experimental/tests/model/test_marginal_model.py
+++ b/pymc_experimental/tests/model/test_marginal_model.py
@@ -6,12 +6,14 @@
 import pymc as pm
 import pytensor.tensor as pt
 import pytest
+import scipy
 from arviz import InferenceData, dict_to_dataset
 from pymc.distributions import transforms
 from pymc.logprob.abstract import _logprob
 from pymc.model.fgraph import fgraph_from_model
 from pymc.pytensorf import inputvars
 from pymc.util import UNSET
+from pytensor.graph import FunctionGraph
 from scipy.special import log_softmax, logsumexp
 from scipy.stats import halfnorm, norm
@@ -21,6 +23,7 @@
     MarginalModel,
     is_conditional_dependent,
     marginalize,
+    replace_continuous_marginal_subgraph,
 )
 from pymc_experimental.tests.utils import equal_computations_up_to_root
@@ -803,3 +806,63 @@ def create_model(model_class):
         marginal_m.compile_logp()(ip),
         reference_m.compile_logp()(ip),
     )
+
+
+@pytest.mark.parametrize("univariate", (True, False), ids=["univariate", "multivariate"])
+@pytest.mark.parametrize( + "multiple_dependent", (False, True), ids=["single-dependent", "multiple-dependent"] +) +def test_marginalize_normal_qmc(univariate, multiple_dependent): + with MarginalModel() as m: + SD = pm.HalfNormal("SD", default_transform=None) + if univariate: + X = pm.Normal("X", sigma=SD, shape=(3,)) + else: + X = pm.MvNormal("X", mu=[0, 0, 0], cov=np.eye(3) * SD**2) + + if multiple_dependent: + Y = [ + pm.Normal("Y[0]", mu=(2 * X[0] + 1), sigma=1, observed=1), + pm.Normal("Y[1:]", mu=(2 * X[1:] + 1), sigma=1, observed=[2, 3]), + ] + else: + Y = [pm.Normal("Y", mu=(2 * X + 1), sigma=1, observed=[1, 2, 3])] + + m.marginalize([X]) # ideally method="qmc" + + logp_eval = np.hstack(m.compile_logp(vars=Y, sum=False)({"SD": 2.0})) + + np.testing.assert_allclose( + logp_eval, + scipy.stats.norm.logpdf([1, 2, 3], 1, np.sqrt(17)), + rtol=1e-5, + ) + + +def test_marginalize_non_trivial_mvnormal_qmc(): + with MarginalModel() as m: + SD = pm.HalfNormal("SD", default_transform=None) + X = pm.MvNormal("X", cov=[[1.0, 0.5], [0.5, 1.0]] * SD**2) + Y = pm.MvNormal("Y", mu=2 * X + 1, cov=np.eye(2), observed=[1, 2]) + + m.marginalize([X]) + + [logp_eval] = m.compile_logp(vars=Y, sum=False)({"SD": 1}) + + np.testing.assert_allclose( + logp_eval, + scipy.stats.multivariate_normal.logpdf([1, 2], [1, 1], [[5, 2], [2, 5]]), + rtol=1e-5, + ) + + +def test_marginalize_sample(): + with pm.Model() as m: + SD = pm.HalfNormal("SD") + X = pm.Normal.dist(sigma=SD, name="X") + Y = pm.Normal("Y", mu=(2 * X + 1), sigma=1, observed=[1, 2, 3]) + + fg = FunctionGraph(outputs=[SD, Y, X], clone=False) + old_rvs, new_rvs = replace_continuous_marginal_subgraph(fg, X, [Y, SD, X]) + res1, res2 = pm.draw(new_rvs, draws=2) + assert not np.allclose(res1, res2)
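
Reviewer note (not part of the patch): the logp registered for QMCMarginalNormalRV is a log-mean-exp over Sobol draws pushed through the marginalized RV's inverse CDF, i.e. log p(y) ~= logsumexp_i log p(y | x_i) - log(n_draws). Below is a minimal NumPy/SciPy sketch of that estimator, mirroring test_marginalize_normal_qmc with SD fixed at 2; the seed, rtol, and the choice of 2**13 draws (matching the hard-coded qmc_order=13) are illustrative, not taken from the patch.

import numpy as np
from scipy.special import logsumexp
from scipy.stats import norm
from scipy.stats.qmc import Sobol

# Model: X ~ Normal(0, 2); Y | X ~ Normal(2 * X + 1, 1).
# Marginally, Y ~ Normal(1, sqrt(4 * 2**2 + 1)) = Normal(1, sqrt(17)).
sd_x = 2.0
y_obs = np.array([1.0, 2.0, 3.0])

n_draws = 2**13  # mirrors qmc_order=13 in the patch
uniform_draws = Sobol(d=y_obs.size, seed=123).random(n_draws)  # (n_draws, 3)

# Push the uniforms through the marginalized RV's inverse CDF (the icdf branch above)
x_draws = norm.ppf(uniform_draws, loc=0.0, scale=sd_x)

# log p(y) ~= logsumexp_i log p(y | x_i) - log(n_draws)
logp_cond = norm.logpdf(y_obs, loc=2 * x_draws + 1, scale=1.0)
logp_marginal = logsumexp(logp_cond, axis=0) - np.log(n_draws)

np.testing.assert_allclose(
    logp_marginal,
    norm.logpdf(y_obs, loc=1.0, scale=np.sqrt(17.0)),
    rtol=1e-4,
)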
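
The MvNormal branch bypasses the generic icdf route (there is no well-defined multivariate inverse CDF) and instead colors standard-normal base draws with the transposed Cholesky factor of the covariance, following the scipy qmc code it links to. A sketch of the same idea outside PyTensor, mirroring test_marginalize_non_trivial_mvnormal_qmc with SD fixed at 1; seed, draw count, and rtol are again illustrative.

import numpy as np
from scipy.special import logsumexp
from scipy.stats import multivariate_normal, norm
from scipy.stats.qmc import Sobol

# Model: X ~ MvNormal(0, C), C = [[1, .5], [.5, 1]]; Y | X ~ MvNormal(2 * X + 1, I).
# Marginally, Y ~ MvNormal([1, 1], 4 * C + I) = MvNormal([1, 1], [[5, 2], [2, 5]]).
cov = np.array([[1.0, 0.5], [0.5, 1.0]])
y_obs = np.array([1.0, 2.0])

n_draws = 2**13
uniform_draws = Sobol(d=2, seed=123).random(n_draws)

# Same clipped-ppf + Cholesky transform the patch adapts from scipy's qmc module
base_draws = norm.ppf(0.5 + (1 - 1e-10) * (uniform_draws - 0.5))
x_draws = base_draws @ np.linalg.cholesky(cov).T  # mean is zero here

# log p(y | x) under identity covariance, vectorized through the residuals
resid = y_obs - (2 * x_draws + 1)
logp_cond = multivariate_normal.logpdf(resid, mean=np.zeros(2), cov=np.eye(2))
logp_marginal = logsumexp(logp_cond) - np.log(n_draws)

np.testing.assert_allclose(
    logp_marginal,
    multivariate_normal.logpdf(y_obs, mean=[1.0, 1.0], cov=[[5.0, 2.0], [2.0, 5.0]]),
    rtol=1e-4,
)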