From 09a3e0bde3cf12fd0f8a63a3ddf4a32e316ec7c9 Mon Sep 17 00:00:00 2001 From: Yerdos Ordabayev <50752571+ordabayevy@users.noreply.github.com> Date: Sun, 22 Jan 2023 22:07:57 -0500 Subject: [PATCH] Funsor based `TraceEnum_ELBO` implementation (#1512) * initial commit * wip * remove TraceMarkovEnum_ELBO * pair coded * clean up * Make enum example work * port tests from pyro to numpyro * Add missing test file * traceenum_elbo2 * pair coded * pass more tests * pass all tests * organize * lint * fixes * test_gradient * fix TraceGraph_ELBO * lint * Revert tracegraph_elbo changes * Address masked distribution * revert changes at replay messenger * refactor * clean * add validations * fix comments * lint * fix validation * fix * fix enum_vars * rm wordclouds.png * address comments Co-authored-by: Du Phan --- numpyro/contrib/funsor/enum_messenger.py | 6 + numpyro/distributions/kl.py | 9 + numpyro/handlers.py | 7 + numpyro/infer/__init__.py | 2 + numpyro/infer/elbo.py | 303 ++- numpyro/util.py | 29 +- test/contrib/test_enum_elbo.py | 2257 ++++++++++++++++++++++ test/infer/test_gradient.py | 139 ++ 8 files changed, 2745 insertions(+), 7 deletions(-) create mode 100644 test/contrib/test_enum_elbo.py create mode 100644 test/infer/test_gradient.py diff --git a/numpyro/contrib/funsor/enum_messenger.py b/numpyro/contrib/funsor/enum_messenger.py index af4bb61d1..e8e541948 100644 --- a/numpyro/contrib/funsor/enum_messenger.py +++ b/numpyro/contrib/funsor/enum_messenger.py @@ -521,6 +521,12 @@ def _get_batch_shape(cond_indep_stack): def process_message(self, msg): if msg["type"] in ["to_funsor", "to_data"]: return super().process_message(msg) + if msg["type"] == "sample" and self.size != self.subsample_size: + plate_to_scale = msg.setdefault("plate_to_scale", {}) + assert self.name not in plate_to_scale + plate_to_scale[self.name] = ( + self.size / self.subsample_size if self.subsample_size else 1 + ) return OrigPlateMessenger.process_message(self, msg) def postprocess_message(self, msg): diff --git a/numpyro/distributions/kl.py b/numpyro/distributions/kl.py index d82b6c59b..2fcfecb2f 100644 --- a/numpyro/distributions/kl.py +++ b/numpyro/distributions/kl.py @@ -39,6 +39,7 @@ Normal, Weibull, ) +from numpyro.distributions.discrete import CategoricalProbs from numpyro.distributions.distribution import ( Delta, Distribution, @@ -146,6 +147,14 @@ def kl_divergence(p, q): return t1 - t2 + t3 +@dispatch(CategoricalProbs, CategoricalProbs) +def kl_divergence(p, q): + t = p.probs * (p.logits - q.logits) + t = jnp.where(q.probs == 0, jnp.inf, t) + t = jnp.where(p.probs == 0, 0.0, t) + return t.sum(-1) + + @dispatch(Dirichlet, Dirichlet) def kl_divergence(p, q): # From http://bariskurt.com/kullback-leibler-divergence-between-two-dirichlet-and-beta-distributions/ diff --git a/numpyro/handlers.py b/numpyro/handlers.py index a63ddb33b..5fb1e24d5 100644 --- a/numpyro/handlers.py +++ b/numpyro/handlers.py @@ -609,6 +609,13 @@ def process_message(self, msg): msg["scale"] = ( self.scale if msg.get("scale") is None else self.scale * msg["scale"] ) + plate_to_scale = msg.setdefault("plate_to_scale", {}) + scale = ( + self.scale + if plate_to_scale.get(None) is None + else self.scale * plate_to_scale[None] + ) + plate_to_scale[None] = scale class scope(Messenger): diff --git a/numpyro/infer/__init__.py b/numpyro/infer/__init__.py index b26ec5f92..d21d77ce4 100644 --- a/numpyro/infer/__init__.py +++ b/numpyro/infer/__init__.py @@ -6,6 +6,7 @@ ELBO, RenyiELBO, Trace_ELBO, + TraceEnum_ELBO, TraceGraph_ELBO, TraceMeanField_ELBO, ) 
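Note on the kl.py hunk above: the new `kl_divergence(CategoricalProbs, CategoricalProbs)` rule uses `jnp.where` so that components with q_i = 0 contribute inf unless the matching p_i is also 0, in which case they contribute 0 (the 0 * log 0 convention). A minimal sanity check of that dispatch, as a sketch assuming a numpyro build with this patch applied (the probability values are illustrative):

    import jax.numpy as jnp

    from numpyro.distributions import CategoricalProbs
    from numpyro.distributions.kl import kl_divergence

    p = CategoricalProbs(jnp.array([0.75, 0.25]))
    q = CategoricalProbs(jnp.array([0.5, 0.5]))
    # Closed form with no zero entries: sum_i p_i * (log p_i - log q_i).
    expected = jnp.sum(p.probs * (jnp.log(p.probs) - jnp.log(q.probs)))
    assert jnp.allclose(kl_divergence(p, q), expected, atol=1e-6)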
@@ -49,6 +50,7 @@
     "SA",
     "SVI",
     "Trace_ELBO",
+    "TraceEnum_ELBO",
     "TraceGraph_ELBO",
     "TraceMeanField_ELBO",
 ]
diff --git a/numpyro/infer/elbo.py b/numpyro/infer/elbo.py
index 668f2acca..7e1ebfab0 100644
--- a/numpyro/infer/elbo.py
+++ b/numpyro/infer/elbo.py
@@ -1,8 +1,8 @@
 # Copyright Contributors to the Pyro project.
 # SPDX-License-Identifier: Apache-2.0
 
-from collections import defaultdict
-from functools import partial
+from collections import OrderedDict, defaultdict
+from functools import partial, reduce
 from operator import itemgetter
 import warnings
 
@@ -11,10 +11,16 @@
 import jax.numpy as jnp
 from jax.scipy.special import logsumexp
 
+from numpyro.distributions import ExpandedDistribution, MaskedDistribution
 from numpyro.distributions.kl import kl_divergence
 from numpyro.distributions.util import scale_and_mask
 from numpyro.handlers import Messenger, replay, seed, substitute, trace
-from numpyro.infer.util import get_importance_trace, log_density
+from numpyro.infer.util import (
+    _without_rsample_stop_gradient,
+    get_importance_trace,
+    is_identically_one,
+    log_density,
+)
 from numpyro.ops.provenance import eval_provenance, get_provenance
 from numpyro.util import _validate_model, check_model_guide_match, find_stack_level
 
@@ -710,3 +716,294 @@ def single_particle_elbo(rng_key):
         else:
             rng_keys = random.split(rng_key, self.num_particles)
             return -jnp.mean(vmap(single_particle_elbo)(rng_keys))
+
+
+def get_importance_trace_enum(model, guide, args, kwargs, params, max_plate_nesting):
+    """
+    (EXPERIMENTAL) Returns traces from the enumerated guide and from the
+    enumerated model that is replayed against it. The returned traces also
+    store the log probability at each site and the log measure for measure vars.
+    """
+    import funsor
+    from numpyro.contrib.funsor import (
+        enum,
+        plate_to_enum_plate,
+        to_funsor,
+        trace as _trace,
+    )
+
+    with plate_to_enum_plate(), enum(
+        first_available_dim=(-max_plate_nesting - 1) if max_plate_nesting else None
+    ):
+        guide = substitute(guide, data=params)
+        with _without_rsample_stop_gradient():
+            guide_trace = _trace(guide).get_trace(*args, **kwargs)
+        model = substitute(replay(model, guide_trace), data=params)
+        model_trace = _trace(model).get_trace(*args, **kwargs)
+    guide_trace = {
+        name: site for name, site in guide_trace.items() if site["type"] == "sample"
+    }
+    model_trace = {
+        name: site for name, site in model_trace.items() if site["type"] == "sample"
+    }
+    for is_model, tr in zip((False, True), (guide_trace, model_trace)):
+        for name, site in tr.items():
+            if is_model and (site["is_observed"] or (site["name"] in guide_trace)):
+                site["is_measure"] = False
+            if "log_prob" not in site:
+                value = site["value"]
+                intermediates = site["intermediates"]
+                if intermediates:
+                    log_prob = site["fn"].log_prob(value, intermediates)
+                else:
+                    log_prob = site["fn"].log_prob(value)
+
+                dim_to_name = site["infer"]["dim_to_name"]
+                site["log_prob"] = to_funsor(
+                    log_prob, output=funsor.Real, dim_to_name=dim_to_name
+                )
+                if site.get("is_measure", True):
+                    # get rid of masking
+                    base_fn = site["fn"]
+                    batch_shape = base_fn.batch_shape
+                    while isinstance(
+                        base_fn, (MaskedDistribution, ExpandedDistribution)
+                    ):
+                        base_fn = base_fn.base_dist
+                    base_fn = base_fn.expand(batch_shape)
+                    if intermediates:
+                        log_measure = base_fn.log_prob(value, intermediates)
+                    else:
+                        log_measure = base_fn.log_prob(value)
+                    # dice factor
+                    if site["infer"].get("enumerate") != "parallel":
+                        log_measure = log_measure - funsor.ops.detach(log_measure)
+                    site["log_measure"] = to_funsor(
+                        log_measure, output=funsor.Real, dim_to_name=dim_to_name
+                    )
+    return model_trace, guide_trace
+
+
+def _partition(model_sum_deps, sum_vars):
+    # Construct a bipartite graph between model_sum_deps and the sum_vars
+    neighbors = OrderedDict([(t, []) for t in model_sum_deps.keys()])
+    for key, deps in model_sum_deps.items():
+        for dim in deps:
+            if dim in sum_vars:
+                neighbors[key].append(dim)
+                neighbors.setdefault(dim, []).append(key)
+
+    # Partition the bipartite graph into connected components for contraction.
+    components = []
+    while neighbors:
+        v, pending = neighbors.popitem()
+        component = OrderedDict([(v, None)])  # used as an OrderedSet
+        for v in pending:
+            component[v] = None
+        while pending:
+            v = pending.pop()
+            for v in neighbors.pop(v):
+                if v not in component:
+                    component[v] = None
+                    pending.append(v)
+
+        # Split this connected component into factors and measures.
+        # Append only if component_factors is non-empty.
+        component_factors = frozenset(v for v in component if v not in sum_vars)
+        if component_factors:
+            component_measures = frozenset(v for v in component if v in sum_vars)
+            components.append((component_factors, component_measures))
+    return components
+
+
+class TraceEnum_ELBO(ELBO):
+    """
+    A TraceEnum implementation of ELBO-based SVI. The gradient estimator
+    is constructed along the lines of reference [1] specialized to the case
+    of the ELBO. It supports arbitrary dependency structure for the model
+    and guide.
+    Fine-grained conditional dependency information as recorded in the
+    trace is used to reduce the variance of the gradient estimator.
+    In particular, provenance tracking [2] is used to find the ``cost`` terms
+    that depend on each non-reparameterizable sample site.
+    Enumerated variables are eliminated using the TVE algorithm for plated
+    factor graphs [3].
+
+    References
+
+    [1] `Storchastic: A Framework for General Stochastic Automatic Differentiation`,
+        Emile van Krieken, Jakub M. Tomczak, Annette ten Teije
+
+    [2] `Nonstandard Interpretations of Probabilistic Programs for Efficient Inference`,
+        David Wingate, Noah Goodman, Andreas Stuhlmüller, Jeffrey Siskind
+
+    [3] `Tensor Variable Elimination for Plated Factor Graphs`,
+        Fritz Obermeyer, Eli Bingham, Martin Jankowiak, Justin Chiu,
+        Neeraj Pradhan, Alexander M. Rush, Noah Goodman
+    """
+
+    can_infer_discrete = True
+
+    def __init__(self, num_particles=1, max_plate_nesting=float("inf")):
+        if max_plate_nesting == float("inf"):
+            raise ValueError(
+                "Currently, we require `max_plate_nesting` to be a non-negative integer."
+            )
+        self.max_plate_nesting = max_plate_nesting
+        super().__init__(num_particles=num_particles)
+
+    def loss(self, rng_key, param_map, model, guide, *args, **kwargs):
+        def single_particle_elbo(rng_key):
+            import funsor
+            from numpyro.contrib.funsor import to_data, to_funsor
+
+            model_seed, guide_seed = random.split(rng_key)
+            seeded_model = seed(model, model_seed)
+            seeded_guide = seed(guide, guide_seed)
+
+            model_trace, guide_trace = get_importance_trace_enum(
+                seeded_model,
+                seeded_guide,
+                args,
+                kwargs,
+                param_map,
+                self.max_plate_nesting,
+            )
+            check_model_guide_match(model_trace, guide_trace)
+            _validate_model(model_trace, plate_warning="strict")
+
+            # Find dependencies on non-reparameterizable sample sites for
+            # each cost term in the model and the guide.
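+            # For example (hypothetical model, for illustration only): if the
+            # guide samples a non-reparameterizable z and the model has a cost
+            # term log p(x | z), provenance tracking reports that the x cost
+            # depends on z, so that cost is multiplied below by z's dice factor
+            # exp(log q(z) - stop_gradient(log q(z))) -- identically 1 in
+            # value, but carrying the score-function term in the gradient.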
+ model_deps, guide_deps = get_provenance( + eval_provenance( + partial( + track_nonreparam(get_importance_log_probs), + seeded_model, + seeded_guide, + args, + kwargs, + param_map, + ) + ) + ) + + sum_vars = frozenset( + [ + name + for name, site in model_trace.items() + if site.get("is_measure", True) + ] + ) + model_sum_deps = { + k: v & sum_vars for k, v in model_deps.items() if k not in sum_vars + } + model_deps = { + k: v - sum_vars for k, v in model_deps.items() if k not in sum_vars + } + + elbo = 0.0 + for group_names, group_sum_vars in _partition(model_sum_deps, sum_vars): + if not group_sum_vars: + # uncontracted logp cost term + assert len(group_names) == 1 + name = next(iter(group_names)) + cost = model_trace[name]["log_prob"] + scale = model_trace[name]["scale"] + deps = model_deps[name] + dice_factors = [guide_trace[key]["log_measure"] for key in deps] + else: + # compute contracted cost term + group_factors = tuple( + model_trace[name]["log_prob"] for name in group_names + ) + group_factors += tuple( + model_trace[var]["log_measure"] for var in group_sum_vars + ) + group_factor_vars = frozenset().union( + *[f.inputs for f in group_factors] + ) + group_plates = group_factor_vars - frozenset(model_trace.keys()) + outermost_plates = frozenset.intersection( + *(frozenset(f.inputs) & group_plates for f in group_factors) + ) + elim_plates = group_plates - outermost_plates + cost = funsor.sum_product.sum_product( + funsor.ops.logaddexp, + funsor.ops.add, + group_factors, + plates=group_plates, + eliminate=group_sum_vars | elim_plates, + ) + # incorporate the effects of subsampling and handlers.scale through a common scale factor + group_scales = {} + for name in group_names: + for plate, value in ( + model_trace[name].get("plate_to_scale", {}).items() + ): + if plate in group_scales: + if value != group_scales[plate]: + raise ValueError( + "Expected all enumerated sample sites to share a common scale factor, " + f"but found different scales at plate('{plate}')." + ) + else: + group_scales[plate] = value + scale = ( + reduce(lambda a, b: a * b, group_scales.values()) + if group_scales + else None + ) + # combine deps + deps = frozenset().union( + *[model_deps[name] for name in group_names] + ) + # check model guide enumeration constraint + for key in deps: + site = guide_trace[key] + if site["infer"].get("enumerate") == "parallel": + for plate in ( + frozenset(site["log_measure"].inputs) & elim_plates + ): + raise ValueError( + "Expected model enumeration to be no more global than guide enumeration, but found " + f"model enumeration sites upstream of guide site '{key}' in plate('{plate}')." + "Try converting some model enumeration sites to guide enumeration sites." 
+ ) + # combine dice factors + dice_factors = [ + guide_trace[key]["log_measure"].reduce( + funsor.ops.add, + frozenset(guide_trace[key]["log_measure"].inputs) + & elim_plates, + ) + for key in deps + ] + + if dice_factors: + dice_factor = reduce(lambda a, b: a + b, dice_factors) + cost = cost * funsor.ops.exp(dice_factor) + if (scale is not None) and (not is_identically_one(scale)): + cost = cost * to_funsor(scale) + + elbo = elbo + cost.reduce(funsor.ops.add) + + for name, deps in guide_deps.items(): + # -logq cost term + cost = -guide_trace[name]["log_prob"] + scale = guide_trace[name]["scale"] + dice_factors = [guide_trace[key]["log_measure"] for key in deps] + if dice_factors: + dice_factor = reduce(lambda a, b: a + b, dice_factors) + cost = cost * funsor.ops.exp(dice_factor) + if (scale is not None) and (not is_identically_one(scale)): + cost = cost * to_funsor(scale) + elbo = elbo + cost.reduce(funsor.ops.add) + + return to_data(elbo) + + # Return (-elbo) since by convention we do gradient descent on a loss and + # the ELBO is a lower bound that needs to be maximized. + if self.num_particles == 1: + return -single_particle_elbo(rng_key) + else: + rng_keys = random.split(rng_key, self.num_particles) + return -jnp.mean(vmap(single_particle_elbo)(rng_keys)) diff --git a/numpyro/util.py b/numpyro/util.py index 9f5cc19c0..1b402e48b 100644 --- a/numpyro/util.py +++ b/numpyro/util.py @@ -516,6 +516,15 @@ def model(*args, **kwargs): def _validate_model(model_trace, plate_warning="loose"): # TODO: Consider exposing global configuration for those strategies. assert plate_warning in ["loose", "strict", "error"] + enum_dims = set( + [ + site["infer"]["name_to_dim"][name] + for name, site in model_trace.items() + if site["type"] == "sample" + and site["infer"].get("enumerate") == "parallel" + and site["infer"].get("name_to_dim") is not None + ] + ) # Check if plate is missing in the model. for name, site in model_trace.items(): if site["type"] == "sample": @@ -528,7 +537,7 @@ def _validate_model(model_trace, plate_warning="loose"): batch_ndim = len(batch_shape) for i in range(batch_ndim): dim = -i - 1 - if batch_shape[dim] > 1 and (dim not in plate_dims): + if batch_shape[dim] > 1 and (dim not in (plate_dims | enum_dims)): # Skip checking if it is the `scan` dimension. if dim == -batch_ndim and site.get("_control_flow_done", False): continue @@ -576,10 +585,22 @@ def check_model_guide_match(model_trace, guide_trace): model_vars = set( name for name, site in model_trace.items() - if site["type"] == "sample" and not site.get("is_observed", False) + if site["type"] == "sample" + and not site.get("is_observed", False) + and not ( + name not in guide_trace and site["infer"].get("enumerate") == "parallel" + ) + ) + enum_vars = set( + [ + name + for name, site in model_trace.items() + if site["type"] == "sample" + and not site.get("is_observed", False) + and name not in guide_trace + and site["infer"].get("enumerate") == "parallel" + ] ) - # TODO: Collect enum variables when TraceEnum_ELBO is supported. - enum_vars = set() if aux_vars & model_vars: warnings.warn( diff --git a/test/contrib/test_enum_elbo.py b/test/contrib/test_enum_elbo.py new file mode 100644 index 000000000..1ccfd602a --- /dev/null +++ b/test/contrib/test_enum_elbo.py @@ -0,0 +1,2257 @@ +# Copyright Contributors to the Pyro project. 
+# SPDX-License-Identifier: Apache-2.0 + +import logging + +import numpy as np +import pytest + +import jax +from jax import random +import jax.numpy as jnp + +import numpyro as pyro +from numpyro import handlers, infer +import numpyro.distributions as dist +from numpyro.distributions import constraints +from numpyro.ops.indexing import Vindex + +# put all funsor-related imports here, so test collection works without funsor +try: + import funsor + import numpyro.contrib.funsor + from numpyro.contrib.funsor import config_enumerate + + funsor.set_backend("jax") +except ImportError: + pytestmark = pytest.mark.skip(reason="funsor is not installed") + +logger = logging.getLogger(__name__) + +transform = dist.biject_to(dist.constraints.simplex) + + +def assert_equal(a, b, prec=0): + return jax.tree_util.tree_map( + lambda a, b: np.testing.assert_allclose(a, b, atol=prec), a, b + ) + + +def xfail_param(*args, **kwargs): + kwargs.setdefault("reason", "unknown") + return pytest.param(*args, marks=[pytest.mark.xfail(**kwargs)]) + + +@pytest.mark.parametrize("inner_dim", [2]) +@pytest.mark.parametrize("outer_dim", [2]) +def test_elbo_plate_plate(outer_dim, inner_dim): + q = jnp.array([0.75, 0.25]) + p = 0.2693204236205713 # for which kl(Categorical(q), Categorical(p)) = 0.5 + p = jnp.array([p, 1 - p]) + + def model(q): + d = dist.Categorical(p) + context1 = pyro.plate("outer", outer_dim, dim=-1) + context2 = pyro.plate("inner", inner_dim, dim=-2) + pyro.sample("w", d) + with context1: + pyro.sample("x", d) + with context2: + pyro.sample("y", d) + with context1, context2: + pyro.sample("z", d) + + def guide(q): + d = dist.Categorical(q) + context1 = pyro.plate("outer", outer_dim, dim=-1) + context2 = pyro.plate("inner", inner_dim, dim=-2) + pyro.sample("w", d, infer={"enumerate": "parallel"}) + with context1: + pyro.sample("x", d, infer={"enumerate": "parallel"}) + with context2: + pyro.sample("y", d, infer={"enumerate": "parallel"}) + with context1, context2: + pyro.sample("z", d, infer={"enumerate": "parallel"}) + + def expected_loss_fn(q): + kl_node = pyro.distributions.kl_divergence( + dist.Categorical(q), dist.Categorical(p) + ) + kl = (1 + outer_dim + inner_dim + outer_dim * inner_dim) * kl_node + return kl + + expected_loss, expected_grad = jax.value_and_grad(expected_loss_fn)(q) + + def actual_loss_fn(q): + elbo = infer.TraceEnum_ELBO(max_plate_nesting=2) + return elbo.loss(random.PRNGKey(0), {}, model, guide, q) + + actual_loss, actual_grad = jax.value_and_grad(actual_loss_fn)(q) + + assert_equal(actual_loss, expected_loss, prec=1e-5) + assert_equal(actual_grad, expected_grad, prec=1e-5) + + +@pytest.mark.parametrize("scale", [1, 10]) +def test_elbo_enumerate_1(scale): + params = {} + params["guide_probs_x"] = jnp.array([0.1, 0.9]) + params["model_probs_x"] = jnp.array([0.4, 0.6]) + params["model_probs_y"] = jnp.array([[0.75, 0.25], [0.55, 0.45]]) + params["model_probs_z"] = jnp.array([0.3, 0.7]) + + @handlers.scale(scale=scale) + def auto_model(params): + probs_x = params["model_probs_x"] + probs_y = params["model_probs_y"] + probs_z = params["model_probs_z"] + x = pyro.sample("x", dist.Categorical(probs_x)) + pyro.sample("y", dist.Categorical(probs_y[x]), infer={"enumerate": "parallel"}) + pyro.sample("z", dist.Categorical(probs_z), obs=jnp.array(0)) + + @handlers.scale(scale=scale) + def hand_model(params): + probs_x = params["model_probs_x"] + probs_z = params["model_probs_z"] + pyro.sample("x", dist.Categorical(probs_x)) + pyro.sample("z", dist.Categorical(probs_z), obs=jnp.array(0)) + + 
@handlers.scale(scale=scale) + def guide(params): + probs_x = params["guide_probs_x"] + pyro.sample("x", dist.Categorical(probs_x), infer={"enumerate": "parallel"}) + + def auto_loss_fn(params_raw): + params = jax.tree_util.tree_map(transform, params_raw) + elbo = infer.TraceEnum_ELBO(max_plate_nesting=0) + return elbo.loss(random.PRNGKey(0), {}, auto_model, guide, params) + + def hand_loss_fn(params_raw): + params = jax.tree_util.tree_map(transform, params_raw) + elbo = infer.TraceEnum_ELBO(max_plate_nesting=0) + return elbo.loss(random.PRNGKey(0), {}, hand_model, guide, params) + + params_raw = jax.tree_util.tree_map(transform.inv, params) + auto_loss, auto_grad = jax.value_and_grad(auto_loss_fn)(params_raw) + hand_loss, hand_grad = jax.value_and_grad(hand_loss_fn)(params_raw) + + assert_equal(auto_loss, hand_loss, prec=1e-5) + assert_equal(auto_grad, hand_grad, prec=1e-5) + + +@pytest.mark.parametrize("scale", [1, 10]) +def test_elbo_enumerate_2(scale): + params = {} + params["guide_probs_x"] = jnp.array([0.1, 0.9]) + params["model_probs_x"] = jnp.array([0.4, 0.6]) + params["model_probs_y"] = jnp.array([[0.75, 0.25], [0.55, 0.45]]) + params["model_probs_z"] = jnp.array([[0.3, 0.7], [0.2, 0.8]]) + + @handlers.scale(scale=scale) + def auto_model(params): + probs_x = params["model_probs_x"] + probs_y = params["model_probs_y"] + probs_z = params["model_probs_z"] + x = pyro.sample("x", dist.Categorical(probs_x)) + y = pyro.sample( + "y", dist.Categorical(probs_y[x]), infer={"enumerate": "parallel"} + ) + pyro.sample("z", dist.Categorical(probs_z[y]), obs=0) + + @handlers.scale(scale=scale) + def hand_model(params): + probs_x = params["model_probs_x"] + probs_y = params["model_probs_y"] + probs_z = params["model_probs_z"] + probs_yz = probs_y @ probs_z + x = pyro.sample("x", dist.Categorical(probs_x)) + pyro.sample("z", dist.Categorical(probs_yz[x]), obs=0) + + @config_enumerate + @handlers.scale(scale=scale) + def guide(params): + probs_x = params["guide_probs_x"] + pyro.sample("x", dist.Categorical(probs_x)) + + def auto_loss_fn(params_raw): + params = jax.tree_util.tree_map(transform, params_raw) + elbo = infer.TraceEnum_ELBO(max_plate_nesting=0) + return elbo.loss(random.PRNGKey(0), {}, auto_model, guide, params) + + def hand_loss_fn(params_raw): + params = jax.tree_util.tree_map(transform, params_raw) + elbo = infer.TraceEnum_ELBO(max_plate_nesting=0) + return elbo.loss(random.PRNGKey(0), {}, hand_model, guide, params) + + params_raw = jax.tree_util.tree_map(transform.inv, params) + auto_loss, auto_grad = jax.value_and_grad(auto_loss_fn)(params_raw) + hand_loss, hand_grad = jax.value_and_grad(hand_loss_fn)(params_raw) + + assert_equal(auto_loss, hand_loss, prec=1e-5) + assert_equal(auto_grad, hand_grad, prec=1e-5) + + +@pytest.mark.parametrize("scale", [1, 10]) +def test_elbo_enumerate_3(scale): + params = {} + params["guide_probs_x"] = jnp.array([0.1, 0.9]) + params["model_probs_x"] = jnp.array([0.4, 0.6]) + params["model_probs_y"] = jnp.array([[0.75, 0.25], [0.55, 0.45]]) + params["model_probs_z"] = jnp.array([[0.3, 0.7], [0.2, 0.8]]) + + def auto_model(params): + probs_x = params["model_probs_x"] + probs_y = params["model_probs_y"] + probs_z = params["model_probs_z"] + x = pyro.sample("x", dist.Categorical(probs_x)) + with handlers.scale(scale=scale): + y = pyro.sample( + "y", dist.Categorical(probs_y[x]), infer={"enumerate": "parallel"} + ) + pyro.sample("z", dist.Categorical(probs_z[y]), obs=0) + + def hand_model(params): + probs_x = params["model_probs_x"] + probs_y = 
params["model_probs_y"] + probs_z = params["model_probs_z"] + probs_yz = probs_y @ probs_z + x = pyro.sample("x", dist.Categorical(probs_x)) + with handlers.scale(scale=scale): + pyro.sample("z", dist.Categorical(probs_yz[x]), obs=0) + + @config_enumerate + def guide(params): + probs_x = params["guide_probs_x"] + pyro.sample("x", dist.Categorical(probs_x)) + + def auto_loss_fn(params_raw): + params = jax.tree_util.tree_map(transform, params_raw) + elbo = infer.TraceEnum_ELBO(max_plate_nesting=0) + return elbo.loss(random.PRNGKey(0), {}, auto_model, guide, params) + + def hand_loss_fn(params_raw): + params = jax.tree_util.tree_map(transform, params_raw) + elbo = infer.TraceEnum_ELBO(max_plate_nesting=0) + return elbo.loss(random.PRNGKey(0), {}, hand_model, guide, params) + + params_raw = jax.tree_util.tree_map(transform.inv, params) + auto_loss, auto_grad = jax.value_and_grad(auto_loss_fn)(params_raw) + hand_loss, hand_grad = jax.value_and_grad(hand_loss_fn)(params_raw) + + assert_equal(auto_loss, hand_loss, prec=1e-5) + assert_equal(auto_grad, hand_grad, prec=1e-5) + + +@pytest.mark.parametrize("scale", [1, 10]) +@pytest.mark.parametrize( + "num_samples,num_masked", [(2, 2), (3, 2)], ids=["batch", "masked"] +) +def test_elbo_enumerate_plate_1(num_samples, num_masked, scale): + # +---------+ + # x ----> y ----> z | + # | N | + # +---------+ + params = {} + params["guide_probs_x"] = jnp.array([0.1, 0.9]) + params["model_probs_x"] = jnp.array([0.4, 0.6]) + params["model_probs_y"] = jnp.array([[0.75, 0.25], [0.55, 0.45]]) + params["model_probs_z"] = jnp.array([[0.3, 0.7], [0.2, 0.8]]) + + def auto_model(data, params): + probs_x = params["model_probs_x"] + probs_y = params["model_probs_y"] + probs_z = params["model_probs_z"] + x = pyro.sample("x", dist.Categorical(probs_x)) + with handlers.scale(scale=scale): + y = pyro.sample( + "y", dist.Categorical(probs_y[x]), infer={"enumerate": "parallel"} + ) + if num_masked == num_samples: + with pyro.plate("data", len(data)): + pyro.sample("z", dist.Categorical(probs_z[y]), obs=data) + else: + with pyro.plate("data", len(data)): + with handlers.mask(mask=jnp.arange(num_samples) < num_masked): + pyro.sample("z", dist.Categorical(probs_z[y]), obs=data) + + def hand_model(data, params): + probs_x = params["model_probs_x"] + probs_y = params["model_probs_y"] + probs_z = params["model_probs_z"] + x = pyro.sample("x", dist.Categorical(probs_x)) + with handlers.scale(scale=scale): + y = pyro.sample( + "y", dist.Categorical(probs_y[x]), infer={"enumerate": "parallel"} + ) + for i in range(num_masked): + pyro.sample(f"z_{i}", dist.Categorical(probs_z[y]), obs=data[i]) + + @config_enumerate + def guide(data, params): + probs_x = params["guide_probs_x"] + pyro.sample("x", dist.Categorical(probs_x)) + + data = dist.Categorical(jnp.array([0.3, 0.7])).sample( + random.PRNGKey(0), (num_samples,) + ) + + def auto_loss_fn(params_raw): + params = jax.tree_util.tree_map(transform, params_raw) + elbo = infer.TraceEnum_ELBO(max_plate_nesting=1) + return elbo.loss(random.PRNGKey(0), {}, auto_model, guide, data, params) + + def hand_loss_fn(params_raw): + params = jax.tree_util.tree_map(transform, params_raw) + elbo = infer.TraceEnum_ELBO(max_plate_nesting=1) + return elbo.loss(random.PRNGKey(0), {}, hand_model, guide, data, params) + + params_raw = jax.tree_util.tree_map(transform.inv, params) + auto_loss, auto_grad = jax.value_and_grad(auto_loss_fn)(params_raw) + hand_loss, hand_grad = jax.value_and_grad(hand_loss_fn)(params_raw) + + assert_equal(auto_loss, hand_loss, 
prec=1e-5) + assert_equal(auto_grad, hand_grad, prec=1e-5) + + +@pytest.mark.parametrize("scale", [1, 10]) +@pytest.mark.parametrize( + "num_samples,num_masked", [(2, 2), (3, 2)], ids=["batch", "masked"] +) +def test_elbo_enumerate_plate_2(num_samples, num_masked, scale): + # +-----------------+ + # x ----> y ----> z | + # | N | + # +-----------------+ + params = {} + params["guide_probs_x"] = jnp.array([0.1, 0.9]) + params["model_probs_x"] = jnp.array([0.4, 0.6]) + params["model_probs_y"] = jnp.array([[0.75, 0.25], [0.55, 0.45]]) + params["model_probs_z"] = jnp.array([[0.3, 0.7], [0.2, 0.8]]) + + def auto_model(data, params): + probs_x = params["model_probs_x"] + probs_y = params["model_probs_y"] + probs_z = params["model_probs_z"] + x = pyro.sample("x", dist.Categorical(probs_x)) + with handlers.scale(scale=scale): + with pyro.plate("data", len(data)): + if num_masked == num_samples: + y = pyro.sample( + "y", + dist.Categorical(probs_y[x]), + infer={"enumerate": "parallel"}, + ) + pyro.sample("z", dist.Categorical(probs_z[y]), obs=data) + else: + with handlers.mask(mask=jnp.arange(num_samples) < num_masked): + y = pyro.sample( + "y", + dist.Categorical(probs_y[x]), + infer={"enumerate": "parallel"}, + ) + pyro.sample("z", dist.Categorical(probs_z[y]), obs=data) + + def hand_model(data, params): + probs_x = params["model_probs_x"] + probs_y = params["model_probs_y"] + probs_z = params["model_probs_z"] + x = pyro.sample("x", dist.Categorical(probs_x)) + with handlers.scale(scale=scale): + for i in range(num_masked): + y = pyro.sample( + f"y_{i}", + dist.Categorical(probs_y[x]), + infer={"enumerate": "parallel"}, + ) + pyro.sample(f"z_{i}", dist.Categorical(probs_z[y]), obs=data[i]) + + @config_enumerate + def guide(data, params): + probs_x = params["guide_probs_x"] + pyro.sample("x", dist.Categorical(probs_x)) + + data = dist.Categorical(jnp.array([0.3, 0.7])).sample( + random.PRNGKey(0), (num_samples,) + ) + + def auto_loss_fn(params_raw): + params = jax.tree_util.tree_map(transform, params_raw) + elbo = infer.TraceEnum_ELBO(max_plate_nesting=1) + return elbo.loss(random.PRNGKey(0), {}, auto_model, guide, data, params) + + def hand_loss_fn(params_raw): + params = jax.tree_util.tree_map(transform, params_raw) + elbo = infer.TraceEnum_ELBO(max_plate_nesting=1) + return elbo.loss(random.PRNGKey(0), {}, hand_model, guide, data, params) + + params_raw = jax.tree_util.tree_map(transform.inv, params) + auto_loss, auto_grad = jax.value_and_grad(auto_loss_fn)(params_raw) + hand_loss, hand_grad = jax.value_and_grad(hand_loss_fn)(params_raw) + + assert_equal(auto_loss, hand_loss, prec=1e-5) + assert_equal(auto_grad, hand_grad, prec=1e-5) + + +@pytest.mark.parametrize("scale", [1, 10]) +@pytest.mark.parametrize( + "num_samples,num_masked", [(2, 2), (3, 2)], ids=["batch", "masked"] +) +def test_elbo_enumerate_plate_3(num_samples, num_masked, scale): + # +-----------------------+ + # | x ----> y ----> z | + # | N | + # +-----------------------+ + # This plate should remain unreduced since all enumeration is in a single plate. 
+    params = {}
+    params["guide_probs_x"] = jnp.array([0.1, 0.9])
+    params["model_probs_x"] = jnp.array([0.4, 0.6])
+    params["model_probs_y"] = jnp.array([[0.75, 0.25], [0.55, 0.45]])
+    params["model_probs_z"] = jnp.array([[0.3, 0.7], [0.2, 0.8]])
+
+    @handlers.scale(scale=scale)
+    def auto_model(data, params):
+        probs_x = params["model_probs_x"]
+        probs_y = params["model_probs_y"]
+        probs_z = params["model_probs_z"]
+        with pyro.plate("data", len(data)):
+            if num_masked == num_samples:
+                x = pyro.sample("x", dist.Categorical(probs_x))
+                y = pyro.sample(
+                    "y", dist.Categorical(probs_y[x]), infer={"enumerate": "parallel"}
+                )
+                pyro.sample("z", dist.Categorical(probs_z[y]), obs=data)
+            else:
+                with handlers.mask(mask=jnp.arange(num_samples) < num_masked):
+                    x = pyro.sample("x", dist.Categorical(probs_x))
+                    y = pyro.sample(
+                        "y",
+                        dist.Categorical(probs_y[x]),
+                        infer={"enumerate": "parallel"},
+                    )
+                    pyro.sample("z", dist.Categorical(probs_z[y]), obs=data)
+
+    @handlers.scale(scale=scale)
+    @config_enumerate
+    def auto_guide(data, params):
+        probs_x = params["guide_probs_x"]
+        with pyro.plate("data", len(data)):
+            if num_masked == num_samples:
+                pyro.sample("x", dist.Categorical(probs_x))
+            else:
+                with handlers.mask(mask=jnp.arange(num_samples) < num_masked):
+                    pyro.sample("x", dist.Categorical(probs_x))
+
+    @handlers.scale(scale=scale)
+    def hand_model(data, params):
+        probs_x = params["model_probs_x"]
+        probs_y = params["model_probs_y"]
+        probs_z = params["model_probs_z"]
+        for i in range(num_masked):
+            x = pyro.sample(f"x_{i}", dist.Categorical(probs_x))
+            y = pyro.sample(
+                f"y_{i}",
+                dist.Categorical(probs_y[x]),
+                infer={"enumerate": "parallel"},
+            )
+            pyro.sample(f"z_{i}", dist.Categorical(probs_z[y]), obs=data[i])
+
+    @handlers.scale(scale=scale)
+    @config_enumerate
+    def hand_guide(data, params):
+        probs_x = params["guide_probs_x"]
+        for i in range(num_masked):
+            pyro.sample(f"x_{i}", dist.Categorical(probs_x))
+
+    data = dist.Categorical(jnp.array([0.3, 0.7])).sample(
+        random.PRNGKey(0), (num_samples,)
+    )
+
+    def auto_loss_fn(params_raw):
+        params = jax.tree_util.tree_map(transform, params_raw)
+        elbo = infer.TraceEnum_ELBO(max_plate_nesting=1)
+        return elbo.loss(random.PRNGKey(0), {}, auto_model, auto_guide, data, params)
+
+    def hand_loss_fn(params_raw):
+        params = jax.tree_util.tree_map(transform, params_raw)
+        elbo = infer.TraceEnum_ELBO(max_plate_nesting=1)
+        return elbo.loss(random.PRNGKey(0), {}, hand_model, hand_guide, data, params)
+
+    params_raw = jax.tree_util.tree_map(transform.inv, params)
+    auto_loss, auto_grad = jax.value_and_grad(auto_loss_fn)(params_raw)
+    hand_loss, hand_grad = jax.value_and_grad(hand_loss_fn)(params_raw)
+
+    assert_equal(auto_loss, hand_loss, prec=1e-5)
+    assert_equal(auto_grad, hand_grad, prec=1e-5)
+
+
+@pytest.mark.parametrize("scale", [1, 10])
+@pytest.mark.parametrize(
+    "outer_obs,inner_obs", [(False, True), (True, False), (True, True)]
+)
+def test_elbo_enumerate_plate_4(outer_obs, inner_obs, scale):
+    # a ---> outer_obs
+    #  \
+    # +-----\------------------+
+    # |      \                 |
+    # |  b ---> inner_obs  N=2 |
+    # +------------------------+
+    # This tests two different observations, one outside and one inside a plate.
+ params = {} + params["probs_a"] = jnp.array([0.4, 0.6]) + params["probs_b"] = jnp.array([0.6, 0.4]) + params["locs"] = jnp.array([-1.0, 1.0]) + params["scales"] = jnp.array([1.0, 2.0]) + + outer_data = 2.0 + inner_data = jnp.array([0.5, 1.5]) + + @handlers.scale(scale=scale) + def auto_model(params): + probs_a = pyro.param( + "probs_a", params["probs_a"], constraint=constraints.simplex + ) + probs_b = pyro.param( + "probs_b", params["probs_b"], constraint=constraints.simplex + ) + locs = pyro.param("locs", params["locs"]) + scales = pyro.param("scales", params["scales"], constraint=constraints.positive) + a = pyro.sample("a", dist.Categorical(probs_a), infer={"enumerate": "parallel"}) + if outer_obs: + pyro.sample("outer_obs", dist.Normal(0.0, scales[a]), obs=outer_data) + with pyro.plate("inner", 2): + b = pyro.sample( + "b", dist.Categorical(probs_b), infer={"enumerate": "parallel"} + ) + if inner_obs: + pyro.sample( + "inner_obs", dist.Normal(locs[b], scales[a]), obs=inner_data + ) + + @handlers.scale(scale=scale) + def hand_model(params): + probs_a = pyro.param( + "probs_a", params["probs_a"], constraint=constraints.simplex + ) + probs_b = pyro.param( + "probs_b", params["probs_b"], constraint=constraints.simplex + ) + locs = pyro.param("locs", params["locs"]) + scales = pyro.param("scales", params["scales"], constraint=constraints.positive) + a = pyro.sample("a", dist.Categorical(probs_a), infer={"enumerate": "parallel"}) + if outer_obs: + pyro.sample("outer_obs", dist.Normal(0.0, scales[a]), obs=outer_data) + for i in range(2): + b = pyro.sample( + f"b_{i}", + dist.Categorical(probs_b), + infer={"enumerate": "parallel"}, + ) + if inner_obs: + pyro.sample( + f"inner_obs_{i}", + dist.Normal(locs[b], scales[a]), + obs=inner_data[i], + ) + + def guide(params): + pass + + def auto_loss_fn(params): + elbo = infer.TraceEnum_ELBO(max_plate_nesting=1) + return elbo.loss(random.PRNGKey(0), {}, auto_model, guide, params) + + def hand_loss_fn(params): + elbo = infer.TraceEnum_ELBO(max_plate_nesting=0) + return elbo.loss(random.PRNGKey(0), {}, hand_model, guide, params) + + auto_loss, auto_grad = jax.value_and_grad(auto_loss_fn)(params) + hand_loss, hand_grad = jax.value_and_grad(hand_loss_fn)(params) + + assert_equal(auto_loss, hand_loss, prec=1e-5) + assert_equal(auto_grad, hand_grad, prec=1e-5) + + +def test_elbo_enumerate_plate_5(): + # Guide Model + # a + # +---------------|--+ + # | M=2 V | + # | b ----> c | + # +------------------+ + params = {} + params["model_probs_a"] = jnp.array([0.45, 0.55]) + params["model_probs_b"] = jnp.array([0.6, 0.4]) + params["model_probs_c"] = jnp.array( + [[[0.4, 0.5, 0.1], [0.3, 0.5, 0.2]], [[0.3, 0.4, 0.3], [0.4, 0.4, 0.2]]] + ) + params["guide_probs_b"] = jnp.array([0.8, 0.2]) + data = jnp.array([1, 2]) + + @config_enumerate + def model_plate(params): + probs_a = pyro.param( + "model_probs_a", params["model_probs_a"], constraint=constraints.simplex + ) + probs_b = pyro.param( + "model_probs_b", params["model_probs_b"], constraint=constraints.simplex + ) + probs_c = pyro.param( + "model_probs_c", + params["model_probs_c"], + constraint=constraints.simplex, + ) + a = pyro.sample("a", dist.Categorical(probs_a)) + with pyro.plate("b_axis", 2): + b = pyro.sample("b", dist.Categorical(probs_b)) + pyro.sample("c", dist.Categorical(Vindex(probs_c)[a, b]), obs=data) + + @config_enumerate + def guide_plate(params): + probs_b = pyro.param( + "guide_probs_b", params["guide_probs_b"], constraint=constraints.simplex + ) + with pyro.plate("b_axis", 2): + 
pyro.sample("b", dist.Categorical(probs_b)) + + @config_enumerate + def model_iplate(params): + probs_a = pyro.param( + "model_probs_a", params["model_probs_a"], constraint=constraints.simplex + ) + probs_b = pyro.param( + "model_probs_b", params["model_probs_b"], constraint=constraints.simplex + ) + probs_c = pyro.param( + "model_probs_c", + params["model_probs_c"], + constraint=constraints.simplex, + ) + a = pyro.sample("a", dist.Categorical(probs_a)) + for i in range(2): + b = pyro.sample(f"b_{i}", dist.Categorical(probs_b)) + pyro.sample(f"c_{i}", dist.Categorical(Vindex(probs_c)[a, b]), obs=data[i]) + + @config_enumerate + def guide_iplate(params): + probs_b = pyro.param( + "guide_probs_b", params["guide_probs_b"], constraint=constraints.simplex + ) + for i in range(2): + pyro.sample(f"b_{i}", dist.Categorical(probs_b)) + + def auto_loss_fn(params): + elbo = infer.TraceEnum_ELBO(max_plate_nesting=1) + return elbo.loss(random.PRNGKey(0), {}, model_plate, guide_plate, params) + + def hand_loss_fn(params): + elbo = infer.TraceEnum_ELBO(max_plate_nesting=0) + return elbo.loss(random.PRNGKey(0), {}, model_iplate, guide_iplate, params) + + with pytest.raises( + ValueError, match="Expected model enumeration to be no more global than guide" + ): + hand_loss, hand_grad = jax.value_and_grad(hand_loss_fn)(params) + # This never gets run because we don't support this yet. + auto_loss, auto_grad = jax.value_and_grad(auto_loss_fn)(params) + + assert_equal(auto_loss, hand_loss, prec=1e-5) + assert_equal(auto_grad, hand_grad, prec=1e-5) + + +@pytest.mark.parametrize("enumerate1", ["parallel", "sequential"]) +def test_elbo_enumerate_plate_6(enumerate1): + # Guide Model + # +-------+ + # b ----> c <---- a + # | M=2 | + # +-------+ + # This tests that sequential enumeration over b works, even though + # model-side enumeration moves c into b's plate via contraction. 
+ params = {} + params["model_probs_a"] = jnp.array([0.45, 0.55]) + params["model_probs_b"] = jnp.array([0.6, 0.4]) + params["model_probs_c"] = jnp.array( + [[[0.4, 0.5, 0.1], [0.3, 0.5, 0.2]], [[0.3, 0.4, 0.3], [0.4, 0.4, 0.2]]] + ) + params["guide_probs_b"] = jnp.array([0.8, 0.2]) + data = jnp.array([1, 2]) + + @config_enumerate + def model_plate(params): + probs_a = pyro.param( + "model_probs_a", params["model_probs_a"], constraint=constraints.simplex + ) + probs_b = pyro.param( + "model_probs_b", params["model_probs_b"], constraint=constraints.simplex + ) + probs_c = pyro.param( + "model_probs_c", + params["model_probs_c"], + constraint=constraints.simplex, + ) + a = pyro.sample("a", dist.Categorical(probs_a)) + b = pyro.sample("b", dist.Categorical(probs_b)) + with pyro.plate("b_axis", 2): + pyro.sample("c", dist.Categorical(Vindex(probs_c)[a, b]), obs=data) + + @config_enumerate + def model_iplate(params): + probs_a = pyro.param( + "model_probs_a", params["model_probs_a"], constraint=constraints.simplex + ) + probs_b = pyro.param( + "model_probs_b", params["model_probs_b"], constraint=constraints.simplex + ) + probs_c = pyro.param( + "model_probs_c", + params["model_probs_c"], + constraint=constraints.simplex, + ) + a = pyro.sample("a", dist.Categorical(probs_a)) + b = pyro.sample("b", dist.Categorical(probs_b)) + for i in range(2): + pyro.sample( + "c_{}".format(i), dist.Categorical(Vindex(probs_c)[a, b]), obs=data[i] + ) + + @config_enumerate(default=enumerate1) + def guide(params): + probs_b = pyro.param( + "guide_probs_b", params["guide_probs_b"], constraint=constraints.simplex + ) + pyro.sample("b", dist.Categorical(probs_b)) + + def auto_loss_fn(params): + elbo = infer.TraceEnum_ELBO(max_plate_nesting=1) + return elbo.loss(random.PRNGKey(0), {}, model_plate, guide, params) + + def hand_loss_fn(params): + elbo = infer.TraceEnum_ELBO(max_plate_nesting=0) + return elbo.loss(random.PRNGKey(0), {}, model_iplate, guide, params) + + auto_loss, auto_grad = jax.value_and_grad(auto_loss_fn)(params) + hand_loss, hand_grad = jax.value_and_grad(hand_loss_fn)(params) + + assert_equal(auto_loss, hand_loss, prec=1e-5) + assert_equal(auto_grad, hand_grad, prec=1e-5) + + +@pytest.mark.parametrize("scale", [1, 10]) +def test_elbo_enumerate_plate_7(scale): + # Guide Model + # a -----> b + # | | + # +-|--------|----------------+ + # | V V | + # | c -----> d -----> e N=2 | + # +---------------------------+ + # This tests a mixture of model and guide enumeration. 
+ params = {} + params["model_probs_a"] = jnp.array([0.45, 0.55]) + params["model_probs_b"] = jnp.array([[0.6, 0.4], [0.4, 0.6]]) + params["model_probs_c"] = jnp.array([[0.75, 0.25], [0.55, 0.45]]) + params["model_probs_d"] = jnp.array( + [[[0.4, 0.6], [0.3, 0.7]], [[0.3, 0.7], [0.2, 0.8]]] + ) + params["model_probs_e"] = jnp.array([[0.75, 0.25], [0.55, 0.45]]) + params["guide_probs_a"] = jnp.array([0.35, 0.64]) + params["guide_probs_c"] = jnp.array([[0.0, 1.0], [1.0, 0.0]]) # deterministic + + @handlers.scale(scale=scale) + def auto_model(data, params): + probs_a = pyro.param( + "model_probs_a", params["model_probs_a"], constraint=constraints.simplex + ) + probs_b = pyro.param( + "model_probs_b", params["model_probs_b"], constraint=constraints.simplex + ) + probs_c = pyro.param( + "model_probs_c", + params["model_probs_c"], + constraint=constraints.simplex, + ) + probs_d = pyro.param( + "model_probs_d", + params["model_probs_d"], + constraint=constraints.simplex, + ) + probs_e = pyro.param( + "model_probs_e", + params["model_probs_e"], + constraint=constraints.simplex, + ) + a = pyro.sample("a", dist.Categorical(probs_a)) + b = pyro.sample( + "b", dist.Categorical(probs_b[a]), infer={"enumerate": "parallel"} + ) + with pyro.plate("data", 2): + c = pyro.sample("c", dist.Categorical(probs_c[a])) + d = pyro.sample( + "d", + dist.Categorical(Vindex(probs_d)[b, c]), + infer={"enumerate": "parallel"}, + ) + pyro.sample("obs", dist.Categorical(probs_e[d]), obs=data) + + @handlers.scale(scale=scale) + def auto_guide(data, params): + probs_a = pyro.param( + "guide_probs_a", params["guide_probs_a"], constraint=constraints.simplex + ) + probs_c = pyro.param( + "guide_probs_c", params["guide_probs_c"], constraint=constraints.simplex + ) + a = pyro.sample("a", dist.Categorical(probs_a), infer={"enumerate": "parallel"}) + with pyro.plate("data", 2): + pyro.sample("c", dist.Categorical(probs_c[a])) + + @handlers.scale(scale=scale) + def hand_model(data, params): + probs_a = pyro.param( + "model_probs_a", params["model_probs_a"], constraint=constraints.simplex + ) + probs_b = pyro.param( + "model_probs_b", params["model_probs_b"], constraint=constraints.simplex + ) + probs_c = pyro.param( + "model_probs_c", + params["model_probs_c"], + constraint=constraints.simplex, + ) + probs_d = pyro.param( + "model_probs_d", + params["model_probs_d"], + constraint=constraints.simplex, + ) + probs_e = pyro.param( + "model_probs_e", + params["model_probs_e"], + constraint=constraints.simplex, + ) + a = pyro.sample("a", dist.Categorical(probs_a)) + b = pyro.sample( + "b", dist.Categorical(probs_b[a]), infer={"enumerate": "parallel"} + ) + for i in range(2): + c = pyro.sample(f"c_{i}", dist.Categorical(probs_c[a])) + d = pyro.sample( + f"d_{i}", + dist.Categorical(Vindex(probs_d)[b, c]), + infer={"enumerate": "parallel"}, + ) + pyro.sample(f"obs_{i}", dist.Categorical(probs_e[d]), obs=data[i]) + + @handlers.scale(scale=scale) + def hand_guide(data, params): + probs_a = pyro.param( + "guide_probs_a", params["guide_probs_a"], constraint=constraints.simplex + ) + probs_c = pyro.param( + "guide_probs_c", params["guide_probs_c"], constraint=constraints.simplex + ) + a = pyro.sample("a", dist.Categorical(probs_a), infer={"enumerate": "parallel"}) + for i in range(2): + pyro.sample(f"c_{i}", dist.Categorical(probs_c[a])) + + data = jnp.array([0, 0]) + + def auto_loss_fn(params): + elbo = infer.TraceEnum_ELBO(max_plate_nesting=1) + return elbo.loss(random.PRNGKey(0), {}, auto_model, auto_guide, data, params) + + def 
hand_loss_fn(params): + elbo = infer.TraceEnum_ELBO(max_plate_nesting=0) + return elbo.loss(random.PRNGKey(0), {}, hand_model, hand_guide, data, params) + + auto_loss, auto_grad = jax.value_and_grad(auto_loss_fn)(params) + hand_loss, hand_grad = jax.value_and_grad(hand_loss_fn)(params) + + assert_equal(auto_loss, hand_loss, prec=1e-5) + assert_equal(auto_grad, hand_grad, prec=1e-5) + + +@pytest.mark.parametrize("scale", [1, 10]) +def test_elbo_enumerate_plates_1(scale): + # +-----------------+ + # | a ----> b M=2 | + # +-----------------+ + # +-----------------+ + # | c ----> d N=3 | + # +-----------------+ + # This tests two unrelated plates. + # Each should remain uncontracted. + params = {} + params["probs_a"] = jnp.array([0.45, 0.55]) + params["probs_b"] = jnp.array([[0.6, 0.4], [0.4, 0.6]]) + params["probs_c"] = jnp.array([0.75, 0.25]) + params["probs_d"] = jnp.array([[0.4, 0.6], [0.3, 0.7]]) + + b_data = jnp.array([0, 1]) + d_data = jnp.array([0, 0, 1]) + + @config_enumerate + @handlers.scale(scale=scale) + def auto_model(params): + probs_a = pyro.param( + "probs_a", params["probs_a"], constraint=constraints.simplex + ) + probs_b = pyro.param( + "probs_b", params["probs_b"], constraint=constraints.simplex + ) + probs_c = pyro.param( + "probs_c", params["probs_c"], constraint=constraints.simplex + ) + probs_d = pyro.param( + "probs_d", params["probs_d"], constraint=constraints.simplex + ) + with pyro.plate("a_axis", 2): + a = pyro.sample("a", dist.Categorical(probs_a)) + pyro.sample("b", dist.Categorical(probs_b[a]), obs=b_data) + with pyro.plate("c_axis", 3): + c = pyro.sample("c", dist.Categorical(probs_c)) + pyro.sample("d", dist.Categorical(probs_d[c]), obs=d_data) + + @config_enumerate + @handlers.scale(scale=scale) + def hand_model(params): + probs_a = pyro.param( + "probs_a", params["probs_a"], constraint=constraints.simplex + ) + probs_b = pyro.param( + "probs_b", params["probs_b"], constraint=constraints.simplex + ) + probs_c = pyro.param( + "probs_c", params["probs_c"], constraint=constraints.simplex + ) + probs_d = pyro.param( + "probs_d", params["probs_d"], constraint=constraints.simplex + ) + for i in range(2): + a = pyro.sample(f"a_{i}", dist.Categorical(probs_a)) + pyro.sample(f"b_{i}", dist.Categorical(probs_b[a]), obs=b_data[i]) + for j in range(3): + c = pyro.sample(f"c_{j}", dist.Categorical(probs_c)) + pyro.sample(f"d_{j}", dist.Categorical(probs_d[c]), obs=d_data[j]) + + def guide(params): + pass + + def auto_loss_fn(params): + elbo = infer.TraceEnum_ELBO(max_plate_nesting=1) + return elbo.loss(random.PRNGKey(0), {}, auto_model, guide, params) + + def hand_loss_fn(params): + elbo = infer.TraceEnum_ELBO(max_plate_nesting=0) + return elbo.loss(random.PRNGKey(0), {}, hand_model, guide, params) + + auto_loss, auto_grad = jax.value_and_grad(auto_loss_fn)(params) + hand_loss, hand_grad = jax.value_and_grad(hand_loss_fn)(params) + + assert_equal(auto_loss, hand_loss, prec=1e-5) + assert_equal(auto_grad, hand_grad, prec=1e-5) + + +@pytest.mark.parametrize("scale", [1, 10]) +def test_elbo_enumerate_plates_2(scale): + # +---------+ +---------+ + # | b <---- a ----> c | + # | M=2 | | N=3 | + # +---------+ +---------+ + # This tests two different plates with recycled dimension. 
+    params = {}
+    params["probs_a"] = jnp.array([0.45, 0.55])
+    params["probs_b"] = jnp.array([[0.6, 0.4], [0.4, 0.6]])
+    params["probs_c"] = jnp.array([[0.75, 0.25], [0.55, 0.45]])
+
+    b_data = jnp.array([0, 1])
+    c_data = jnp.array([0, 0, 1])
+
+    @config_enumerate
+    @handlers.scale(scale=scale)
+    def auto_model(params):
+        probs_a = pyro.param(
+            "probs_a", params["probs_a"], constraint=constraints.simplex
+        )
+        probs_b = pyro.param(
+            "probs_b", params["probs_b"], constraint=constraints.simplex
+        )
+        probs_c = pyro.param(
+            "probs_c", params["probs_c"], constraint=constraints.simplex
+        )
+        a = pyro.sample("a", dist.Categorical(probs_a))
+        with pyro.plate("b_axis", 2):
+            pyro.sample("b", dist.Categorical(probs_b[a]), obs=b_data)
+        with pyro.plate("c_axis", 3):
+            pyro.sample("c", dist.Categorical(probs_c[a]), obs=c_data)
+
+    @config_enumerate
+    @handlers.scale(scale=scale)
+    def hand_model(params):
+        probs_a = pyro.param(
+            "probs_a", params["probs_a"], constraint=constraints.simplex
+        )
+        probs_b = pyro.param(
+            "probs_b", params["probs_b"], constraint=constraints.simplex
+        )
+        probs_c = pyro.param(
+            "probs_c", params["probs_c"], constraint=constraints.simplex
+        )
+        a = pyro.sample("a", dist.Categorical(probs_a))
+        for i in range(2):
+            pyro.sample(f"b_{i}", dist.Categorical(probs_b[a]), obs=b_data[i])
+        for j in range(3):
+            pyro.sample(f"c_{j}", dist.Categorical(probs_c[a]), obs=c_data[j])
+
+    def guide(params):
+        pass
+
+    def auto_loss_fn(params):
+        elbo = infer.TraceEnum_ELBO(max_plate_nesting=1)
+        return elbo.loss(random.PRNGKey(0), {}, auto_model, guide, params)
+
+    def hand_loss_fn(params):
+        elbo = infer.TraceEnum_ELBO(max_plate_nesting=0)
+        return elbo.loss(random.PRNGKey(0), {}, hand_model, guide, params)
+
+    auto_loss, auto_grad = jax.value_and_grad(auto_loss_fn)(params)
+    hand_loss, hand_grad = jax.value_and_grad(hand_loss_fn)(params)
+
+    assert_equal(auto_loss, hand_loss, prec=1e-5)
+    assert_equal(auto_grad, hand_grad, prec=1e-5)
+
+
+@pytest.mark.parametrize("scale", [1, 10])
+def test_elbo_enumerate_plates_3(scale):
+    #   +--------------------+
+    #   |  +----------+      |
+    # a -------> b    |      |
+    #   |  |      N=2 |      |
+    #   |  +----------+ M=2  |
+    #   +--------------------+
+    # This tests the case of multiple plate contractions in
+    # a single step.
+ params = {} + params["probs_a"] = jnp.array([0.45, 0.55]) + params["probs_b"] = jnp.array([[0.6, 0.4], [0.4, 0.6]]) + data = jnp.array([[0, 1], [0, 0]]) + + @config_enumerate + @handlers.scale(scale=scale) + def auto_model(params): + probs_a = pyro.param( + "probs_a", params["probs_a"], constraint=constraints.simplex + ) + probs_b = pyro.param( + "probs_b", params["probs_b"], constraint=constraints.simplex + ) + a = pyro.sample("a", dist.Categorical(probs_a)) + with pyro.plate("outer", 2): + with pyro.plate("inner", 2): + pyro.sample("b", dist.Categorical(probs_b[a]), obs=data) + + @config_enumerate + @handlers.scale(scale=scale) + def hand_model(params): + probs_a = pyro.param( + "probs_a", params["probs_a"], constraint=constraints.simplex + ) + probs_b = pyro.param( + "probs_b", params["probs_b"], constraint=constraints.simplex + ) + inner = range(2) + a = pyro.sample("a", dist.Categorical(probs_a)) + for i in range(2): + for j in inner: + pyro.sample(f"b_{i}_{j}", dist.Categorical(probs_b[a]), obs=data[i, j]) + + def guide(params): + pass + + def auto_loss_fn(params): + elbo = infer.TraceEnum_ELBO(max_plate_nesting=2) + return elbo.loss(random.PRNGKey(0), {}, auto_model, guide, params) + + def hand_loss_fn(params): + elbo = infer.TraceEnum_ELBO(max_plate_nesting=0) + return elbo.loss(random.PRNGKey(0), {}, hand_model, guide, params) + + auto_loss, auto_grad = jax.value_and_grad(auto_loss_fn)(params) + hand_loss, hand_grad = jax.value_and_grad(hand_loss_fn)(params) + + assert_equal(auto_loss, hand_loss, prec=1e-5) + assert_equal(auto_grad, hand_grad, prec=1e-5) + + +@pytest.mark.parametrize("scale", [1, 10]) +def test_elbo_enumerate_plates_4(scale): + # +--------------------+ + # | +----------+ | + # a ----> b ----> c | | + # | | N=2 | | + # | M=2 +----------+ | + # +--------------------+ + params = {} + params["probs_a"] = jnp.array([0.45, 0.55]) + params["probs_b"] = jnp.array([[0.6, 0.4], [0.4, 0.6]]) + params["probs_c"] = jnp.array([[0.4, 0.6], [0.3, 0.7]]) + + @config_enumerate + @handlers.scale(scale=scale) + def auto_model(data, params): + probs_a = pyro.param( + "probs_a", params["probs_a"], constraint=constraints.simplex + ) + probs_b = pyro.param( + "probs_b", params["probs_b"], constraint=constraints.simplex + ) + probs_c = pyro.param( + "probs_c", params["probs_c"], constraint=constraints.simplex + ) + a = pyro.sample("a", dist.Categorical(probs_a)) + with pyro.plate("outer", 2): + b = pyro.sample("b", dist.Categorical(probs_b[a])) + with pyro.plate("inner", 2): + pyro.sample("c", dist.Categorical(probs_c[b]), obs=data) + + @config_enumerate + @handlers.scale(scale=scale) + def hand_model(data, params): + probs_a = pyro.param( + "probs_a", params["probs_a"], constraint=constraints.simplex + ) + probs_b = pyro.param( + "probs_b", params["probs_b"], constraint=constraints.simplex + ) + probs_c = pyro.param( + "probs_c", params["probs_c"], constraint=constraints.simplex + ) + inner = range(2) + a = pyro.sample("a", dist.Categorical(probs_a)) + for i in range(2): + b = pyro.sample(f"b_{i}", dist.Categorical(probs_b[a])) + for j in inner: + pyro.sample(f"c_{i}_{j}", dist.Categorical(probs_c[b]), obs=data[i, j]) + + def guide(data, params): + pass + + data = jnp.array([[0, 1], [0, 0]]) + + def auto_loss_fn(params): + elbo = infer.TraceEnum_ELBO(max_plate_nesting=2) + return elbo.loss(random.PRNGKey(0), {}, auto_model, guide, data, params) + + def hand_loss_fn(params): + elbo = infer.TraceEnum_ELBO(max_plate_nesting=0) + return elbo.loss(random.PRNGKey(0), {}, hand_model, guide, 
data, params) + + auto_loss, auto_grad = jax.value_and_grad(auto_loss_fn)(params) + hand_loss, hand_grad = jax.value_and_grad(hand_loss_fn)(params) + + assert_equal(auto_loss, hand_loss, prec=1e-5) + assert_equal(auto_grad, hand_grad, prec=1e-5) + + +@pytest.mark.parametrize("scale", [1, 10]) +def test_elbo_enumerate_plates_5(scale): + # a + # | \ + # +--|---\------------+ + # | V +-\--------+ | + # | b ----> c | | + # | | N=2 | | + # | M=2 +----------+ | + # +-------------------+ + params = {} + params["probs_a"] = jnp.array([0.45, 0.55]) + params["probs_b"] = jnp.array([[0.6, 0.4], [0.4, 0.6]]) + params["probs_c"] = jnp.array([[[0.4, 0.6], [0.3, 0.7]], [[0.2, 0.8], [0.1, 0.9]]]) + data = jnp.array([[0, 1], [0, 0]]) + + @config_enumerate + @handlers.scale(scale=scale) + def auto_model(params): + probs_a = pyro.param( + "probs_a", params["probs_a"], constraint=constraints.simplex + ) + probs_b = pyro.param( + "probs_b", params["probs_b"], constraint=constraints.simplex + ) + probs_c = pyro.param( + "probs_c", params["probs_c"], constraint=constraints.simplex + ) + a = pyro.sample("a", dist.Categorical(probs_a)) + with pyro.plate("outer", 2): + b = pyro.sample("b", dist.Categorical(probs_b[a])) + with pyro.plate("inner", 2): + pyro.sample("c", dist.Categorical(Vindex(probs_c)[a, b]), obs=data) + + @config_enumerate + @handlers.scale(scale=scale) + def hand_model(params): + probs_a = pyro.param( + "probs_a", params["probs_a"], constraint=constraints.simplex + ) + probs_b = pyro.param( + "probs_b", params["probs_b"], constraint=constraints.simplex + ) + probs_c = pyro.param( + "probs_c", params["probs_c"], constraint=constraints.simplex + ) + inner = range(2) + a = pyro.sample("a", dist.Categorical(probs_a)) + for i in range(2): + b = pyro.sample(f"b_{i}", dist.Categorical(probs_b[a])) + for j in inner: + pyro.sample( + f"c_{i}_{j}", + dist.Categorical(Vindex(probs_c)[a, b]), + obs=data[i, j], + ) + + def guide(params): + pass + + def auto_loss_fn(params): + elbo = infer.TraceEnum_ELBO(max_plate_nesting=2) + return elbo.loss(random.PRNGKey(0), {}, auto_model, guide, params) + + def hand_loss_fn(params): + elbo = infer.TraceEnum_ELBO(max_plate_nesting=0) + return elbo.loss(random.PRNGKey(0), {}, hand_model, guide, params) + + auto_loss, auto_grad = jax.value_and_grad(auto_loss_fn)(params) + hand_loss, hand_grad = jax.value_and_grad(hand_loss_fn)(params) + + assert_equal(auto_loss, hand_loss, prec=1e-5) + assert_equal(auto_grad, hand_grad, prec=1e-5) + + +@pytest.mark.parametrize("scale", [1, 10]) +def test_elbo_enumerate_plates_6(scale): + # +----------+ + # | M=2 | + # a ----> b | + # | | | | + # +--|-------|--+ | + # | V | V | | + # | c ----> d | | + # | | | | + # | N=2 +------|---+ + # +-------------+ + # This tests different ways of mixing two independence contexts, + # where each can be either sequential or vectorized plate. 
+
+
+@pytest.mark.parametrize("scale", [1, 10])
+def test_elbo_enumerate_plates_6(scale):
+    #         +----------+
+    #         |      M=2 |
+    #     a ----> b      |
+    #     |   |   |      |
+    #  +--|-------|--+   |
+    #  |  V   |   V  |   |
+    #  |  c ----> d  |   |
+    #  |      |      |   |
+    #  | N=2  +------|---+
+    #  +-------------+
+    # This tests different ways of mixing two independence contexts,
+    # where each can be either sequential or vectorized plate.
+    params = {}
+    params["probs_a"] = jnp.array([0.45, 0.55])
+    params["probs_b"] = jnp.array([[0.6, 0.4], [0.4, 0.6]])
+    params["probs_c"] = jnp.array([[0.75, 0.25], [0.55, 0.45]])
+    params["probs_d"] = jnp.array([[[0.4, 0.6], [0.3, 0.7]], [[0.3, 0.7], [0.2, 0.8]]])
+
+    @config_enumerate
+    @handlers.scale(scale=scale)
+    def model_iplate_iplate(data, params):
+        probs_a = pyro.param(
+            "probs_a", params["probs_a"], constraint=constraints.simplex
+        )
+        probs_b = pyro.param(
+            "probs_b", params["probs_b"], constraint=constraints.simplex
+        )
+        probs_c = pyro.param(
+            "probs_c", params["probs_c"], constraint=constraints.simplex
+        )
+        probs_d = pyro.param(
+            "probs_d", params["probs_d"], constraint=constraints.simplex
+        )
+        b_axis = range(2)
+        c_axis = range(2)
+        a = pyro.sample("a", dist.Categorical(probs_a))
+        b = [pyro.sample(f"b_{i}", dist.Categorical(probs_b[a])) for i in b_axis]
+        c = [pyro.sample(f"c_{j}", dist.Categorical(probs_c[a])) for j in c_axis]
+        for i in b_axis:
+            for j in c_axis:
+                pyro.sample(
+                    f"d_{i}_{j}",
+                    dist.Categorical(Vindex(probs_d)[b[i], c[j]]),
+                    obs=data[i, j],
+                )
+
+    @config_enumerate
+    @handlers.scale(scale=scale)
+    def model_iplate_plate(data, params):
+        probs_a = pyro.param(
+            "probs_a", params["probs_a"], constraint=constraints.simplex
+        )
+        probs_b = pyro.param(
+            "probs_b", params["probs_b"], constraint=constraints.simplex
+        )
+        probs_c = pyro.param(
+            "probs_c", params["probs_c"], constraint=constraints.simplex
+        )
+        probs_d = pyro.param(
+            "probs_d", params["probs_d"], constraint=constraints.simplex
+        )
+        b_axis = range(2)
+        c_axis = pyro.plate("c_axis", 2)
+        a = pyro.sample("a", dist.Categorical(probs_a))
+        with c_axis:
+            c = pyro.sample("c", dist.Categorical(probs_c[a]))
+        for i in b_axis:
+            b_i = pyro.sample(f"b_{i}", dist.Categorical(probs_b[a]))
+            with c_axis:
+                pyro.sample(
+                    f"d_{i}",
+                    dist.Categorical(Vindex(probs_d)[b_i, c]),
+                    obs=data[i],
+                )
+
+    @config_enumerate
+    @handlers.scale(scale=scale)
+    def model_plate_iplate(data, params):
+        probs_a = pyro.param(
+            "probs_a", params["probs_a"], constraint=constraints.simplex
+        )
+        probs_b = pyro.param(
+            "probs_b", params["probs_b"], constraint=constraints.simplex
+        )
+        probs_c = pyro.param(
+            "probs_c", params["probs_c"], constraint=constraints.simplex
+        )
+        probs_d = pyro.param(
+            "probs_d", params["probs_d"], constraint=constraints.simplex
+        )
+        b_axis = pyro.plate("b_axis", 2)
+        c_axis = range(2)
+        a = pyro.sample("a", dist.Categorical(probs_a))
+        with b_axis:
+            b = pyro.sample("b", dist.Categorical(probs_b[a]))
+        c = [pyro.sample(f"c_{j}", dist.Categorical(probs_c[a])) for j in c_axis]
+        with b_axis:
+            for j in c_axis:
+                pyro.sample(
+                    f"d_{j}",
+                    dist.Categorical(Vindex(probs_d)[b, c[j]]),
+                    obs=data[:, j],
+                )
+
+    @config_enumerate
+    @handlers.scale(scale=scale)
+    def model_plate_plate(data, params):
+        probs_a = pyro.param(
+            "probs_a", params["probs_a"], constraint=constraints.simplex
+        )
+        probs_b = pyro.param(
+            "probs_b", params["probs_b"], constraint=constraints.simplex
+        )
+        probs_c = pyro.param(
+            "probs_c", params["probs_c"], constraint=constraints.simplex
+        )
+        probs_d = pyro.param(
+            "probs_d", params["probs_d"], constraint=constraints.simplex
+        )
+        b_axis = pyro.plate("b_axis", 2, dim=-1)
+        c_axis = pyro.plate("c_axis", 2, dim=-2)
+        a = pyro.sample("a", dist.Categorical(probs_a))
+        with b_axis:
+            b = pyro.sample("b", dist.Categorical(probs_b[a]))
+        with c_axis:
+            c = pyro.sample("c", dist.Categorical(probs_c[a]))
+        with b_axis, c_axis:
+            pyro.sample("d", dist.Categorical(Vindex(probs_d)[b, c]), obs=data)
+
+    def guide(data, params):
+        pass
+
+    # Check that either one of the sequential plates can be promoted to be vectorized.
+    data = jnp.array([[0, 1], [0, 0]])
+
+    def iplate_iplate_loss_fn(params):
+        elbo = infer.TraceEnum_ELBO(max_plate_nesting=0)
+        return elbo.loss(
+            random.PRNGKey(0), {}, model_iplate_iplate, guide, data, params
+        )
+
+    def plate_iplate_loss_fn(params):
+        elbo = infer.TraceEnum_ELBO(max_plate_nesting=1)
+        return elbo.loss(random.PRNGKey(0), {}, model_plate_iplate, guide, data, params)
+
+    def iplate_plate_loss_fn(params):
+        elbo = infer.TraceEnum_ELBO(max_plate_nesting=1)
+        return elbo.loss(random.PRNGKey(0), {}, model_iplate_plate, guide, data, params)
+
+    iplate_iplate_loss, iplate_iplate_grad = jax.value_and_grad(iplate_iplate_loss_fn)(
+        params
+    )
+    plate_iplate_loss, plate_iplate_grad = jax.value_and_grad(plate_iplate_loss_fn)(
+        params
+    )
+    iplate_plate_loss, iplate_plate_grad = jax.value_and_grad(iplate_plate_loss_fn)(
+        params
+    )
+
+    assert_equal(iplate_iplate_loss, plate_iplate_loss, prec=1e-5)
+    assert_equal(iplate_iplate_grad, plate_iplate_grad, prec=1e-5)
+    assert_equal(iplate_iplate_loss, iplate_plate_loss, prec=1e-5)
+    assert_equal(iplate_iplate_grad, iplate_plate_grad, prec=1e-5)
+
+    # But promoting both to plates should result in an error.
+    with pytest.raises(ValueError, match="intractable!"):
+        elbo = infer.TraceEnum_ELBO(max_plate_nesting=2)
+        elbo.loss(random.PRNGKey(0), {}, model_plate_plate, guide, data, params)
+
+
+@pytest.mark.parametrize("scale", [1, 10])
+def test_elbo_enumerate_plates_7(scale):
+    #         +-------------+
+    #         |         N=2 |
+    #     a -------> c      |
+    #     |   |      |      |
+    #  +--|----------|--+   |
+    #  |  |   |      V  |   |
+    #  |  V   |      e  |   |
+    #  |  b ----> d     |   |
+    #  |      |         |   |
+    #  | M=2  +---------|---+
+    #  +----------------+
+    # This tests tree-structured dependencies among variables but
+    # non-tree dependencies among plate nestings.
+    params = {}
+    params["probs_a"] = jnp.array([0.45, 0.55])
+    params["probs_b"] = jnp.array([[0.6, 0.4], [0.4, 0.6]])
+    params["probs_c"] = jnp.array([[0.75, 0.25], [0.55, 0.45]])
+    params["probs_d"] = jnp.array([[0.3, 0.7], [0.2, 0.8]])
+    params["probs_e"] = jnp.array([[0.4, 0.6], [0.3, 0.7]])
+
+    @config_enumerate
+    @handlers.scale(scale=scale)
+    def model_iplate_iplate(data, params):
+        probs_a = pyro.param(
+            "probs_a", params["probs_a"], constraint=constraints.simplex
+        )
+        probs_b = pyro.param(
+            "probs_b", params["probs_b"], constraint=constraints.simplex
+        )
+        probs_c = pyro.param(
+            "probs_c", params["probs_c"], constraint=constraints.simplex
+        )
+        probs_d = pyro.param(
+            "probs_d", params["probs_d"], constraint=constraints.simplex
+        )
+        probs_e = pyro.param(
+            "probs_e", params["probs_e"], constraint=constraints.simplex
+        )
+        b_axis = range(2)
+        c_axis = range(2)
+        a = pyro.sample("a", dist.Categorical(probs_a))
+        b = [pyro.sample(f"b_{i}", dist.Categorical(probs_b[a])) for i in b_axis]
+        c = [pyro.sample(f"c_{j}", dist.Categorical(probs_c[a])) for j in c_axis]
+        for i in b_axis:
+            for j in c_axis:
+                pyro.sample(
+                    f"d_{i}_{j}",
+                    dist.Categorical(probs_d[b[i]]),
+                    obs=data[i, j],
+                )
+                pyro.sample(
+                    f"e_{i}_{j}",
+                    dist.Categorical(probs_e[c[j]]),
+                    obs=data[i, j],
+                )
+
+    @config_enumerate
+    @handlers.scale(scale=scale)
+    def model_iplate_plate(data, params):
+        probs_a = pyro.param(
+            "probs_a", params["probs_a"], constraint=constraints.simplex
+        )
+        probs_b = pyro.param(
+            "probs_b", params["probs_b"], constraint=constraints.simplex
+        )
+        probs_c = pyro.param(
+            "probs_c", params["probs_c"], constraint=constraints.simplex
+        )
+        probs_d = pyro.param(
+            "probs_d", params["probs_d"], constraint=constraints.simplex
+        )
+        probs_e = pyro.param(
+            "probs_e", params["probs_e"], constraint=constraints.simplex
+        )
+        b_axis = range(2)
+        c_axis = pyro.plate("c_axis", 2)
+        a = pyro.sample("a", dist.Categorical(probs_a))
+        with c_axis:
+            c = pyro.sample("c", dist.Categorical(probs_c[a]))
+        for i in b_axis:
+            b_i = pyro.sample(f"b_{i}", dist.Categorical(probs_b[a]))
+            with c_axis:
+                pyro.sample(f"d_{i}", dist.Categorical(probs_d[b_i]), obs=data[i])
+                pyro.sample(f"e_{i}", dist.Categorical(probs_e[c]), obs=data[i])
+
+    @config_enumerate
+    @handlers.scale(scale=scale)
+    def model_plate_iplate(data, params):
+        probs_a = pyro.param(
+            "probs_a", params["probs_a"], constraint=constraints.simplex
+        )
+        probs_b = pyro.param(
+            "probs_b", params["probs_b"], constraint=constraints.simplex
+        )
+        probs_c = pyro.param(
+            "probs_c", params["probs_c"], constraint=constraints.simplex
+        )
+        probs_d = pyro.param(
+            "probs_d", params["probs_d"], constraint=constraints.simplex
+        )
+        probs_e = pyro.param(
+            "probs_e", params["probs_e"], constraint=constraints.simplex
+        )
+        b_axis = pyro.plate("b_axis", 2)
+        c_axis = range(2)
+        a = pyro.sample("a", dist.Categorical(probs_a))
+        with b_axis:
+            b = pyro.sample("b", dist.Categorical(probs_b[a]))
+        c = [pyro.sample(f"c_{j}", dist.Categorical(probs_c[a])) for j in c_axis]
+        with b_axis:
+            for j in c_axis:
+                pyro.sample(f"d_{j}", dist.Categorical(probs_d[b]), obs=data[:, j])
+                pyro.sample(f"e_{j}", dist.Categorical(probs_e[c[j]]), obs=data[:, j])
+
+    @config_enumerate
+    @handlers.scale(scale=scale)
+    def model_plate_plate(data, params):
+        probs_a = pyro.param(
+            "probs_a", params["probs_a"], constraint=constraints.simplex
+        )
+        probs_b = pyro.param(
+            "probs_b", params["probs_b"], constraint=constraints.simplex
+        )
+        probs_c = pyro.param(
+            "probs_c", params["probs_c"], constraint=constraints.simplex
+        )
+        probs_d = pyro.param(
+            "probs_d", params["probs_d"], constraint=constraints.simplex
+        )
+        probs_e = pyro.param(
+            "probs_e", params["probs_e"], constraint=constraints.simplex
+        )
+        b_axis = pyro.plate("b_axis", 2, dim=-1)
+        c_axis = pyro.plate("c_axis", 2, dim=-2)
+        a = pyro.sample("a", dist.Categorical(probs_a))
+        with b_axis:
+            b = pyro.sample("b", dist.Categorical(probs_b[a]))
+        with c_axis:
+            c = pyro.sample("c", dist.Categorical(probs_c[a]))
+        with b_axis, c_axis:
+            pyro.sample("d", dist.Categorical(probs_d[b]), obs=data)
+            pyro.sample("e", dist.Categorical(probs_e[c]), obs=data)
+
+    def guide(data, params):
+        pass
+
+    # Check that any combination of sequential plates can be promoted to be vectorized.
+    data = jnp.array([[0, 1], [0, 0]])
+
+    def iplate_iplate_loss_fn(params):
+        elbo = infer.TraceEnum_ELBO(max_plate_nesting=0)
+        return elbo.loss(
+            random.PRNGKey(0), {}, model_iplate_iplate, guide, data, params
+        )
+
+    def plate_iplate_loss_fn(params):
+        elbo = infer.TraceEnum_ELBO(max_plate_nesting=1)
+        return elbo.loss(random.PRNGKey(0), {}, model_plate_iplate, guide, data, params)
+
+    def iplate_plate_loss_fn(params):
+        elbo = infer.TraceEnum_ELBO(max_plate_nesting=1)
+        return elbo.loss(random.PRNGKey(0), {}, model_iplate_plate, guide, data, params)
+
+    def plate_plate_loss_fn(params):
+        elbo = infer.TraceEnum_ELBO(max_plate_nesting=2)
+        return elbo.loss(random.PRNGKey(0), {}, model_plate_plate, guide, data, params)
+
+    iplate_iplate_loss, iplate_iplate_grad = jax.value_and_grad(iplate_iplate_loss_fn)(
+        params
+    )
+    plate_iplate_loss, plate_iplate_grad = jax.value_and_grad(plate_iplate_loss_fn)(
+        params
+    )
+    iplate_plate_loss, iplate_plate_grad = jax.value_and_grad(iplate_plate_loss_fn)(
+        params
+    )
+    plate_plate_loss, plate_plate_grad = jax.value_and_grad(plate_plate_loss_fn)(params)
+
+    assert_equal(iplate_iplate_loss, plate_iplate_loss, prec=1e-4)
+    assert_equal(iplate_iplate_grad, plate_iplate_grad, prec=1e-4)
+    assert_equal(iplate_iplate_loss, iplate_plate_loss, prec=1e-4)
+    assert_equal(iplate_iplate_grad, iplate_plate_grad, prec=1e-4)
+    assert_equal(iplate_iplate_loss, plate_plate_loss, prec=1e-4)
+    assert_equal(iplate_iplate_grad, plate_plate_grad, prec=1e-4)
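+
+
+# Reminder of the Vindex semantics used throughout (illustrative shapes only):
+# with probs of shape (2, 2, 3) and broadcastable integer arrays a and b,
+#
+#     Vindex(probs)[a, b]
+#
+# behaves like probs[a, b, :], broadcasting a and b against each other, so each
+# batch element picks out its own length-3 probability row.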
pyro.plate("outer", 2): + b = pyro.sample("b", dist.Categorical(probs_b)) + with pyro.plate("inner", 2): + pyro.sample("c", dist.Categorical(Vindex(probs_c)[a, b]), obs=data) + + @config_enumerate + @handlers.scale(scale=model_scale) + def model_iplate_plate(params): + probs_a = pyro.param( + "probs_a", params["model_probs_a"], constraint=constraints.simplex + ) + probs_b = pyro.param( + "probs_b", params["model_probs_b"], constraint=constraints.simplex + ) + probs_c = pyro.param( + "probs_c", params["model_probs_c"], constraint=constraints.simplex + ) + inner = pyro.plate("inner", 2) + a = pyro.sample("a", dist.Categorical(probs_a)) + for i in range(2): + b = pyro.sample(f"b_{i}", dist.Categorical(probs_b)) + with inner: + pyro.sample( + f"c_{i}", + dist.Categorical(Vindex(probs_c)[a, b]), + obs=data[:, i], + ) + + @config_enumerate + @handlers.scale(scale=model_scale) + def model_plate_iplate(params): + probs_a = pyro.param( + "probs_a", params["model_probs_a"], constraint=constraints.simplex + ) + probs_b = pyro.param( + "probs_b", params["model_probs_b"], constraint=constraints.simplex + ) + probs_c = pyro.param( + "probs_c", params["model_probs_c"], constraint=constraints.simplex + ) + a = pyro.sample("a", dist.Categorical(probs_a)) + with pyro.plate("outer", 2): + b = pyro.sample("b", dist.Categorical(probs_b)) + for j in range(2): + pyro.sample( + f"c_{j}", + dist.Categorical(Vindex(probs_c)[a, b]), + obs=data[j], + ) + + @config_enumerate + @handlers.scale(scale=model_scale) + def model_iplate_iplate(params): + probs_a = pyro.param( + "probs_a", params["model_probs_a"], constraint=constraints.simplex + ) + probs_b = pyro.param( + "probs_b", params["model_probs_b"], constraint=constraints.simplex + ) + probs_c = pyro.param( + "probs_c", params["model_probs_c"], constraint=constraints.simplex + ) + inner = range(2) + a = pyro.sample("a", dist.Categorical(probs_a)) + for i in range(2): + b = pyro.sample(f"b_{i}", dist.Categorical(probs_b)) + for j in inner: + pyro.sample( + f"c_{i}_{j}", + dist.Categorical(Vindex(probs_c)[a, b]), + obs=data[j, i], + ) + + @config_enumerate + @handlers.scale(scale=guide_scale) + def guide_plate(params): + probs_b = pyro.param( + "probs_b", params["guide_probs_b"], constraint=constraints.simplex + ) + with pyro.plate("outer", 2): + pyro.sample("b", dist.Categorical(probs_b)) + + @config_enumerate + @handlers.scale(scale=guide_scale) + def guide_iplate(params): + probs_b = pyro.param( + "probs_b", params["guide_probs_b"], constraint=constraints.simplex + ) + for i in range(2): + pyro.sample(f"b_{i}", dist.Categorical(probs_b)) + + def iplate_iplate_loss_fn(params): + elbo = infer.TraceEnum_ELBO(max_plate_nesting=0) + return elbo.loss( + random.PRNGKey(0), {}, model_iplate_iplate, guide_iplate, params + ) + + def plate_iplate_loss_fn(params): + elbo = infer.TraceEnum_ELBO(max_plate_nesting=1) + return elbo.loss(random.PRNGKey(0), {}, model_plate_iplate, guide_plate, params) + + def iplate_plate_loss_fn(params): + elbo = infer.TraceEnum_ELBO(max_plate_nesting=1) + return elbo.loss( + random.PRNGKey(0), {}, model_iplate_plate, guide_iplate, params + ) + + def plate_plate_loss_fn(params): + elbo = infer.TraceEnum_ELBO(max_plate_nesting=2) + return elbo.loss(random.PRNGKey(0), {}, model_plate_plate, guide_plate, params) + + expected_loss, expected_grad = jax.value_and_grad(iplate_iplate_loss_fn)(params) + if inner_vectorized: + if outer_vectorized: + with pytest.raises( + ValueError, + match="Expected model enumeration to be no more global than guide", + ): + 
actual_loss, actual_grad = jax.value_and_grad(plate_plate_loss_fn)( + params + ) + assert_equal(actual_loss, expected_loss, prec=1e-4) + assert_equal(actual_grad, expected_grad, prec=1e-4) + else: + actual_loss, actual_grad = jax.value_and_grad(iplate_plate_loss_fn)(params) + assert_equal(actual_loss, expected_loss, prec=1e-4) + assert_equal(actual_grad, expected_grad, prec=1e-4) + else: + if outer_vectorized: + with pytest.raises( + ValueError, + match="Expected model enumeration to be no more global than guide", + ): + actual_loss, actual_grad = jax.value_and_grad(plate_iplate_loss_fn)( + params + ) + assert_equal(actual_loss, expected_loss, prec=1e-4) + assert_equal(actual_grad, expected_grad, prec=1e-4) + else: + actual_loss, actual_grad = jax.value_and_grad(iplate_iplate_loss_fn)(params) + assert_equal(actual_loss, expected_loss, prec=1e-4) + assert_equal(actual_grad, expected_grad, prec=1e-4) + + +def test_elbo_enumerate_plate_9(): + # Model Guide + # a + # +-------|-------+ + # | M=2 V | + # | b -> c | + # +---------------+ + params = {} + params["model_probs_a"] = jnp.array([0.45, 0.55]) + params["model_probs_b"] = jnp.array([[0.3, 0.7], [0.6, 0.4]]) + params["model_probs_c"] = jnp.array([[0.3, 0.4, 0.3], [0.4, 0.4, 0.2]]) + params["guide_probs_a"] = jnp.array([0.45, 0.55]) + params["guide_probs_b"] = jnp.array([[0.3, 0.7], [0.8, 0.2]]) + + data = jnp.array([1, 2]) + + @config_enumerate + def model_plate(params): + probs_a = pyro.param( + "probs_a", params["model_probs_a"], constraint=constraints.simplex + ) + probs_b = pyro.param( + "probs_b", params["model_probs_b"], constraint=constraints.simplex + ) + probs_c = pyro.param( + "probs_c", params["model_probs_c"], constraint=constraints.simplex + ) + a = pyro.sample("a", dist.Categorical(probs_a)) + with pyro.plate("b_axis", 2): + b = pyro.sample("b", dist.Categorical(probs_b[a])) + pyro.sample("c", dist.Categorical(probs_c[b]), obs=data) + + @config_enumerate + def guide_plate(params): + probs_a = pyro.param( + "probs_a", params["guide_probs_a"], constraint=constraints.simplex + ) + probs_b = pyro.param( + "probs_b", params["guide_probs_b"], constraint=constraints.simplex + ) + a = pyro.sample("a", dist.Categorical(probs_a)) + with pyro.plate("b_axis", 2): + pyro.sample("b", dist.Categorical(probs_b[a])) + + @config_enumerate + def model_iplate(params): + probs_a = pyro.param( + "probs_a", params["model_probs_a"], constraint=constraints.simplex + ) + probs_b = pyro.param( + "probs_b", params["model_probs_b"], constraint=constraints.simplex + ) + probs_c = pyro.param( + "probs_c", params["model_probs_c"], constraint=constraints.simplex + ) + a = pyro.sample("a", dist.Categorical(probs_a)) + for i in range(2): + b = pyro.sample(f"b_{i}", dist.Categorical(probs_b[a])) + pyro.sample(f"c_{i}", dist.Categorical(probs_c[b]), obs=data[i]) + + @config_enumerate + def guide_iplate(params): + probs_a = pyro.param( + "probs_a", params["guide_probs_a"], constraint=constraints.simplex + ) + probs_b = pyro.param( + "probs_b", params["guide_probs_b"], constraint=constraints.simplex + ) + a = pyro.sample("a", dist.Categorical(probs_a)) + for i in range(2): + pyro.sample(f"b_{i}", dist.Categorical(probs_b[a])) + + def expected_loss_fn(params): + elbo = infer.TraceEnum_ELBO(max_plate_nesting=0) + return elbo.loss(random.PRNGKey(0), {}, model_iplate, guide_iplate, params) + + def actual_loss_fn(params): + elbo = infer.TraceEnum_ELBO(max_plate_nesting=1) + return elbo.loss(random.PRNGKey(0), {}, model_plate, guide_plate, params) + + expected_loss, 
+
+
+def test_elbo_enumerate_plate_9():
+    #   Model  Guide
+    #          a
+    #  +-------|-------+
+    #  | M=2   V       |
+    #  |       b -> c  |
+    #  +---------------+
+    params = {}
+    params["model_probs_a"] = jnp.array([0.45, 0.55])
+    params["model_probs_b"] = jnp.array([[0.3, 0.7], [0.6, 0.4]])
+    params["model_probs_c"] = jnp.array([[0.3, 0.4, 0.3], [0.4, 0.4, 0.2]])
+    params["guide_probs_a"] = jnp.array([0.45, 0.55])
+    params["guide_probs_b"] = jnp.array([[0.3, 0.7], [0.8, 0.2]])
+
+    data = jnp.array([1, 2])
+
+    @config_enumerate
+    def model_plate(params):
+        probs_a = pyro.param(
+            "probs_a", params["model_probs_a"], constraint=constraints.simplex
+        )
+        probs_b = pyro.param(
+            "probs_b", params["model_probs_b"], constraint=constraints.simplex
+        )
+        probs_c = pyro.param(
+            "probs_c", params["model_probs_c"], constraint=constraints.simplex
+        )
+        a = pyro.sample("a", dist.Categorical(probs_a))
+        with pyro.plate("b_axis", 2):
+            b = pyro.sample("b", dist.Categorical(probs_b[a]))
+            pyro.sample("c", dist.Categorical(probs_c[b]), obs=data)
+
+    @config_enumerate
+    def guide_plate(params):
+        probs_a = pyro.param(
+            "probs_a", params["guide_probs_a"], constraint=constraints.simplex
+        )
+        probs_b = pyro.param(
+            "probs_b", params["guide_probs_b"], constraint=constraints.simplex
+        )
+        a = pyro.sample("a", dist.Categorical(probs_a))
+        with pyro.plate("b_axis", 2):
+            pyro.sample("b", dist.Categorical(probs_b[a]))
+
+    @config_enumerate
+    def model_iplate(params):
+        probs_a = pyro.param(
+            "probs_a", params["model_probs_a"], constraint=constraints.simplex
+        )
+        probs_b = pyro.param(
+            "probs_b", params["model_probs_b"], constraint=constraints.simplex
+        )
+        probs_c = pyro.param(
+            "probs_c", params["model_probs_c"], constraint=constraints.simplex
+        )
+        a = pyro.sample("a", dist.Categorical(probs_a))
+        for i in range(2):
+            b = pyro.sample(f"b_{i}", dist.Categorical(probs_b[a]))
+            pyro.sample(f"c_{i}", dist.Categorical(probs_c[b]), obs=data[i])
+
+    @config_enumerate
+    def guide_iplate(params):
+        probs_a = pyro.param(
+            "probs_a", params["guide_probs_a"], constraint=constraints.simplex
+        )
+        probs_b = pyro.param(
+            "probs_b", params["guide_probs_b"], constraint=constraints.simplex
+        )
+        a = pyro.sample("a", dist.Categorical(probs_a))
+        for i in range(2):
+            pyro.sample(f"b_{i}", dist.Categorical(probs_b[a]))
+
+    def expected_loss_fn(params):
+        elbo = infer.TraceEnum_ELBO(max_plate_nesting=0)
+        return elbo.loss(random.PRNGKey(0), {}, model_iplate, guide_iplate, params)
+
+    def actual_loss_fn(params):
+        elbo = infer.TraceEnum_ELBO(max_plate_nesting=1)
+        return elbo.loss(random.PRNGKey(0), {}, model_plate, guide_plate, params)
+
+    expected_loss, expected_grad = jax.value_and_grad(expected_loss_fn)(params)
+    actual_loss, actual_grad = jax.value_and_grad(actual_loss_fn)(params)
+
+    assert_equal(expected_loss, actual_loss, prec=1e-5)
+    assert_equal(expected_grad, actual_grad, prec=1e-5)
+
+
+def test_elbo_enumerate_plate_10():
+    # Model
+    #   a -> [ [ bij -> cij ] ]
+    # Guide
+    #   a -> [ [ bij ] ]
+    params = {}
+    params["model_probs_a"] = jnp.array([0.45, 0.55])
+    params["model_probs_b"] = jnp.array([[0.3, 0.7], [0.6, 0.4]])
+    params["model_probs_c"] = jnp.array([[0.3, 0.4, 0.3], [0.4, 0.4, 0.2]])
+    params["guide_probs_a"] = jnp.array([0.45, 0.55])
+    params["guide_probs_b"] = jnp.array([[0.3, 0.7], [0.8, 0.2]])
+    data = jnp.array([[0, 1, 2], [1, 2, 2]])
+
+    @config_enumerate
+    def model_plate(params):
+        probs_a = pyro.param(
+            "probs_a", params["model_probs_a"], constraint=constraints.simplex
+        )
+        probs_b = pyro.param(
+            "probs_b", params["model_probs_b"], constraint=constraints.simplex
+        )
+        probs_c = pyro.param(
+            "probs_c", params["model_probs_c"], constraint=constraints.simplex
+        )
+        a = pyro.sample("a", dist.Categorical(probs_a))
+        with pyro.plate("i", 2, dim=-2):
+            with pyro.plate("j", 3, dim=-1):
+                b = pyro.sample("b", dist.Categorical(probs_b[a]))
+                pyro.sample("c", dist.Categorical(probs_c[b]), obs=data)
+
+    @config_enumerate
+    def guide_plate(params):
+        probs_a = pyro.param(
+            "probs_a", params["guide_probs_a"], constraint=constraints.simplex
+        )
+        probs_b = pyro.param(
+            "probs_b", params["guide_probs_b"], constraint=constraints.simplex
+        )
+        a = pyro.sample("a", dist.Categorical(probs_a))
+        with pyro.plate("i", 2, dim=-2):
+            with pyro.plate("j", 3, dim=-1):
+                pyro.sample("b", dist.Categorical(probs_b[a]))
+
+    @config_enumerate
+    def model_iplate(params):
+        probs_a = pyro.param(
+            "probs_a", params["model_probs_a"], constraint=constraints.simplex
+        )
+        probs_b = pyro.param(
+            "probs_b", params["model_probs_b"], constraint=constraints.simplex
+        )
+        probs_c = pyro.param(
+            "probs_c", params["model_probs_c"], constraint=constraints.simplex
+        )
+        a = pyro.sample("a", dist.Categorical(probs_a))
+        for i in range(2):
+            for j in range(3):
+                b = pyro.sample(f"b_{i}_{j}", dist.Categorical(probs_b[a]))
+                pyro.sample(f"c_{i}_{j}", dist.Categorical(probs_c[b]), obs=data[i, j])
+
+    @config_enumerate
+    def guide_iplate(params):
+        probs_a = pyro.param(
+            "probs_a", params["guide_probs_a"], constraint=constraints.simplex
+        )
+        probs_b = pyro.param(
+            "probs_b", params["guide_probs_b"], constraint=constraints.simplex
+        )
+        a = pyro.sample("a", dist.Categorical(probs_a))
+        for i in range(2):
+            for j in range(3):
+                pyro.sample(f"b_{i}_{j}", dist.Categorical(probs_b[a]))
+
+    def expected_loss_fn(params):
+        elbo = infer.TraceEnum_ELBO(max_plate_nesting=0)
+        return elbo.loss(random.PRNGKey(0), {}, model_iplate, guide_iplate, params)
+
+    def actual_loss_fn(params):
+        elbo = infer.TraceEnum_ELBO(max_plate_nesting=2)
+        return elbo.loss(random.PRNGKey(0), {}, model_plate, guide_plate, params)
+
+    expected_loss, expected_grad = jax.value_and_grad(expected_loss_fn)(params)
+    actual_loss, actual_grad = jax.value_and_grad(actual_loss_fn)(params)
+
+    assert_equal(expected_loss, actual_loss, prec=1e-5)
+    assert_equal(expected_grad, actual_grad, prec=1e-5)
+
+
+def test_elbo_enumerate_plate_11():
+    # Model
+    #   [ ai -> [ bij -> cij ] ]
+    # Guide
+    #   [ ai -> [ bij ] ]
+    params = {}
+    params["model_probs_a"] = jnp.array([0.45, 0.55])
+    params["model_probs_b"] = jnp.array([[0.3, 0.7], [0.6, 0.4]])
+    params["model_probs_c"] = jnp.array([[0.3, 0.4, 0.3], [0.4, 0.4, 0.2]])
+    params["guide_probs_a"] = jnp.array([0.45, 0.55])
+    params["guide_probs_b"] = jnp.array([[0.3, 0.7], [0.8, 0.2]])
+    data = jnp.array([[0, 1, 2], [1, 2, 2]])
+
+    @config_enumerate
+    def model_plate(params):
+        probs_a = pyro.param(
+            "probs_a", params["model_probs_a"], constraint=constraints.simplex
+        )
+        probs_b = pyro.param(
+            "probs_b", params["model_probs_b"], constraint=constraints.simplex
+        )
+        probs_c = pyro.param(
+            "probs_c", params["model_probs_c"], constraint=constraints.simplex
+        )
+        with pyro.plate("i", 2, dim=-2):
+            a = pyro.sample("a", dist.Categorical(probs_a))
+            with pyro.plate("j", 3, dim=-1):
+                b = pyro.sample("b", dist.Categorical(probs_b[a]))
+                pyro.sample("c", dist.Categorical(probs_c[b]), obs=data)
+
+    @config_enumerate
+    def guide_plate(params):
+        probs_a = pyro.param(
+            "probs_a", params["guide_probs_a"], constraint=constraints.simplex
+        )
+        probs_b = pyro.param(
+            "probs_b", params["guide_probs_b"], constraint=constraints.simplex
+        )
+        with pyro.plate("i", 2, dim=-2):
+            a = pyro.sample("a", dist.Categorical(probs_a))
+            with pyro.plate("j", 3, dim=-1):
+                pyro.sample("b", dist.Categorical(probs_b[a]))
+
+    @config_enumerate
+    def model_iplate(params):
+        probs_a = pyro.param(
+            "probs_a", params["model_probs_a"], constraint=constraints.simplex
+        )
+        probs_b = pyro.param(
+            "probs_b", params["model_probs_b"], constraint=constraints.simplex
+        )
+        probs_c = pyro.param(
+            "probs_c", params["model_probs_c"], constraint=constraints.simplex
+        )
+        for i in range(2):
+            a = pyro.sample(f"a_{i}", dist.Categorical(probs_a))
+            for j in range(3):
+                b = pyro.sample(f"b_{i}_{j}", dist.Categorical(probs_b[a]))
+                pyro.sample(f"c_{i}_{j}", dist.Categorical(probs_c[b]), obs=data[i, j])
+
+    @config_enumerate
+    def guide_iplate(params):
+        probs_a = pyro.param(
+            "probs_a", params["guide_probs_a"], constraint=constraints.simplex
+        )
+        probs_b = pyro.param(
+            "probs_b", params["guide_probs_b"], constraint=constraints.simplex
+        )
+        for i in range(2):
+            a = pyro.sample(f"a_{i}", dist.Categorical(probs_a))
+            for j in range(3):
+                pyro.sample(f"b_{i}_{j}", dist.Categorical(probs_b[a]))
+
+    def expected_loss_fn(params):
+        elbo = infer.TraceEnum_ELBO(max_plate_nesting=0)
+        return elbo.loss(random.PRNGKey(0), {}, model_iplate, guide_iplate, params)
+
+    def actual_loss_fn(params):
+        elbo = infer.TraceEnum_ELBO(max_plate_nesting=2)
+        return elbo.loss(random.PRNGKey(0), {}, model_plate, guide_plate, params)
+
+    expected_loss, expected_grad = jax.value_and_grad(expected_loss_fn)(params)
+    actual_loss, actual_grad = jax.value_and_grad(actual_loss_fn)(params)
+
+    assert_equal(expected_loss, actual_loss, prec=1e-5)
+    assert_equal(expected_grad, actual_grad, prec=1e-5)
params["model_probs_b"], constraint=constraints.simplex + ) + probs_c = pyro.param( + "probs_c", params["model_probs_c"], constraint=constraints.simplex + ) + probs_d = pyro.param( + "probs_d", params["model_probs_d"], constraint=constraints.simplex + ) + a = pyro.sample("a", dist.Categorical(probs_a)) + with pyro.plate("i", 2, dim=-2): + b = pyro.sample("b", dist.Categorical(probs_b[a])) + with pyro.plate("j", 3, dim=-1): + c = pyro.sample("c", dist.Categorical(probs_c[b])) + pyro.sample("d", dist.Categorical(probs_d[c]), obs=data) + + @config_enumerate + def guide_plate(params): + probs_a = pyro.param( + "probs_a", params["guide_probs_a"], constraint=constraints.simplex + ) + probs_b = pyro.param( + "probs_b", params["guide_probs_b"], constraint=constraints.simplex + ) + probs_c = pyro.param( + "probs_c", params["guide_probs_c"], constraint=constraints.simplex + ) + a = pyro.sample("a", dist.Categorical(probs_a)) + with pyro.plate("i", 2, dim=-2): + b = pyro.sample("b", dist.Categorical(probs_b[a])) + with pyro.plate("j", 3, dim=-1): + pyro.sample("c", dist.Categorical(probs_c[b])) + + @config_enumerate + def model_iplate(params): + probs_a = pyro.param( + "probs_a", params["model_probs_a"], constraint=constraints.simplex + ) + probs_b = pyro.param( + "probs_b", params["model_probs_b"], constraint=constraints.simplex + ) + probs_c = pyro.param( + "probs_c", params["model_probs_c"], constraint=constraints.simplex + ) + probs_d = pyro.param( + "probs_d", params["model_probs_d"], constraint=constraints.simplex + ) + a = pyro.sample("a", dist.Categorical(probs_a)) + for i in range(2): + b = pyro.sample(f"b_{i}", dist.Categorical(probs_b[a])) + for j in range(3): + c = pyro.sample(f"c_{i}_{j}", dist.Categorical(probs_c[b])) + pyro.sample(f"d_{i}_{j}", dist.Categorical(probs_d[c]), obs=data[i, j]) + + @config_enumerate + def guide_iplate(params): + probs_a = pyro.param( + "probs_a", params["guide_probs_a"], constraint=constraints.simplex + ) + probs_b = pyro.param( + "probs_b", params["guide_probs_b"], constraint=constraints.simplex + ) + probs_c = pyro.param( + "probs_c", params["guide_probs_c"], constraint=constraints.simplex + ) + a = pyro.sample("a", dist.Categorical(probs_a)) + for i in range(2): + b = pyro.sample(f"b_{i}", dist.Categorical(probs_b[a])) + for j in range(3): + pyro.sample(f"c_{i}_{j}", dist.Categorical(probs_c[b])) + + def expected_loss_fn(params): + elbo = infer.TraceEnum_ELBO(max_plate_nesting=0) + return elbo.loss(random.PRNGKey(0), {}, model_iplate, guide_iplate, params) + + def actual_loss_fn(params): + elbo = infer.TraceEnum_ELBO(max_plate_nesting=2) + return elbo.loss(random.PRNGKey(0), {}, model_plate, guide_plate, params) + + expected_loss, expected_grad = jax.value_and_grad(expected_loss_fn)(params) + actual_loss, actual_grad = jax.value_and_grad(actual_loss_fn)(params) + + assert_equal(expected_loss, actual_loss, prec=1e-5) + assert_equal(expected_grad, actual_grad, prec=1e-5) + + +def test_elbo_enumerate_plate_13(): + # Model + # a -> [ cj -> [ dij ] ] + # | + # v + # [ bi ] + # Guide + # a -> [ cj ] + # | + # v + # [ bi ] + params = {} + params["model_probs_a"] = jnp.array([0.45, 0.55]) + params["model_probs_b"] = jnp.array([[0.3, 0.7], [0.6, 0.4]]) + params["model_probs_c"] = jnp.array([[0.3, 0.7], [0.4, 0.6]]) + params["model_probs_d"] = jnp.array( + [[0.1, 0.6, 0.3], [0.3, 0.4, 0.3], [0.4, 0.4, 0.2]] + ) + params["guide_probs_a"] = jnp.array([0.45, 0.55]) + params["guide_probs_b"] = jnp.array([[0.3, 0.7], [0.8, 0.2]]) + params["guide_probs_c"] = 
jnp.array([[0.2, 0.8], [0.9, 0.1]]) + data = jnp.array([[0, 1, 2], [1, 2, 2]]) + + @config_enumerate + def model_plate(params): + probs_a = pyro.param( + "probs_a", params["model_probs_a"], constraint=constraints.simplex + ) + probs_b = pyro.param( + "probs_b", params["model_probs_b"], constraint=constraints.simplex + ) + probs_c = pyro.param( + "probs_c", params["model_probs_c"], constraint=constraints.simplex + ) + probs_d = pyro.param( + "probs_d", params["model_probs_d"], constraint=constraints.simplex + ) + a = pyro.sample("a", dist.Categorical(probs_a)) + with pyro.plate("i", 2, dim=-2): + pyro.sample("b", dist.Categorical(probs_b[a])) + with pyro.plate("j", 3, dim=-1): + c = pyro.sample("c", dist.Categorical(probs_c[a])) + pyro.sample("d", dist.Categorical(probs_d[c]), obs=data) + + @config_enumerate + def guide_plate(params): + probs_a = pyro.param( + "probs_a", params["guide_probs_a"], constraint=constraints.simplex + ) + probs_b = pyro.param( + "probs_b", params["guide_probs_b"], constraint=constraints.simplex + ) + probs_c = pyro.param( + "probs_c", params["guide_probs_c"], constraint=constraints.simplex + ) + a = pyro.sample("a", dist.Categorical(probs_a)) + with pyro.plate("i", 2, dim=-2): + pyro.sample("b", dist.Categorical(probs_b[a])) + with pyro.plate("j", 3, dim=-1): + pyro.sample("c", dist.Categorical(probs_c[a])) + + @config_enumerate + def model_iplate(params): + probs_a = pyro.param( + "probs_a", params["model_probs_a"], constraint=constraints.simplex + ) + probs_b = pyro.param( + "probs_b", params["model_probs_b"], constraint=constraints.simplex + ) + probs_c = pyro.param( + "probs_c", params["model_probs_c"], constraint=constraints.simplex + ) + probs_d = pyro.param( + "probs_d", params["model_probs_d"], constraint=constraints.simplex + ) + a = pyro.sample("a", dist.Categorical(probs_a)) + for i in range(2): + pyro.sample(f"b_{i}", dist.Categorical(probs_b[a])) + for j in range(3): + c = pyro.sample(f"c_{i}_{j}", dist.Categorical(probs_c[a])) + pyro.sample(f"d_{i}_{j}", dist.Categorical(probs_d[c]), obs=data[i, j]) + + @config_enumerate + def guide_iplate(params): + probs_a = pyro.param( + "probs_a", params["guide_probs_a"], constraint=constraints.simplex + ) + probs_b = pyro.param( + "probs_b", params["guide_probs_b"], constraint=constraints.simplex + ) + probs_c = pyro.param( + "probs_c", params["guide_probs_c"], constraint=constraints.simplex + ) + a = pyro.sample("a", dist.Categorical(probs_a)) + for i in range(2): + pyro.sample(f"b_{i}", dist.Categorical(probs_b[a])) + for j in range(3): + pyro.sample(f"c_{i}_{j}", dist.Categorical(probs_c[a])) + + def expected_loss_fn(params): + elbo = infer.TraceEnum_ELBO(max_plate_nesting=0) + return elbo.loss(random.PRNGKey(0), {}, model_iplate, guide_iplate, params) + + def actual_loss_fn(params): + elbo = infer.TraceEnum_ELBO(max_plate_nesting=2) + return elbo.loss(random.PRNGKey(0), {}, model_plate, guide_plate, params) + + expected_loss, expected_grad = jax.value_and_grad(expected_loss_fn)(params) + actual_loss, actual_grad = jax.value_and_grad(actual_loss_fn)(params) + + assert_equal(expected_loss, actual_loss, prec=1e-5) + assert_equal(expected_grad, actual_grad, prec=1e-5) diff --git a/test/infer/test_gradient.py b/test/infer/test_gradient.py new file mode 100644 index 000000000..9f9a2eb4d --- /dev/null +++ b/test/infer/test_gradient.py @@ -0,0 +1,139 @@ +# Copyright Contributors to the Pyro project. 
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+
+import numpy as np
+import pytest
+
+import jax
+from jax import random
+import jax.numpy as jnp
+
+import numpyro as pyro
+from numpyro import infer
+from numpyro.contrib.funsor import config_enumerate
+import numpyro.distributions as dist
+from numpyro.distributions import constraints
+from numpyro.ops.indexing import Vindex
+
+logger = logging.getLogger(__name__)
+
+
+def assert_equal(a, b, prec=0):
+    return jax.tree_util.tree_map(
+        lambda a, b: np.testing.assert_allclose(a, b, atol=prec), a, b
+    )
+
+
+def model_0(data, params):
+    with pyro.plate("data", len(data)):
+        z = pyro.sample("z", dist.Categorical(jnp.array([0.3, 0.7])))
+        pyro.sample("x", dist.Normal(z, 1), obs=data)
+
+
+def guide_0(data, params):
+    probs = pyro.param("probs", params["probs"], constraint=constraints.simplex)
+    with pyro.plate("data", len(data)):
+        pyro.sample("z", dist.Categorical(probs))
+
+
+params_0 = {"probs": jnp.array([[0.4, 0.6], [0.5, 0.5]])}
+
+
+def model_1(data, params):
+    a = pyro.sample("a", dist.Categorical(jnp.array([0.3, 0.7])))
+    with pyro.plate("data", len(data)):
+        probs_b = jnp.array([[0.1, 0.9], [0.2, 0.8]])
+        b = pyro.sample("b", dist.Categorical(probs_b[a]))
+        pyro.sample("c", dist.Normal(b, 1), obs=data)
+
+
+def guide_1(data, params):
+    probs_a = pyro.param("probs_a", params["probs_a"], constraint=constraints.simplex)
+    probs_b = pyro.param("probs_b", params["probs_b"], constraint=constraints.simplex)
+    a = pyro.sample("a", dist.Categorical(probs_a))
+    with pyro.plate("data", len(data)) as idx:
+        pyro.sample("b", dist.Categorical(Vindex(probs_b)[a, idx]))
+
+
+params_1 = {
+    "probs_a": jnp.array([0.5, 0.5]),
+    "probs_b": jnp.array([[[0.5, 0.5], [0.6, 0.4]], [[0.4, 0.6], [0.35, 0.65]]]),
+}
+
+
+def model_2(data, params):
+    prob_b = jnp.array([[0.3, 0.7], [0.4, 0.6]])
+    prob_c = jnp.array([[0.5, 0.5], [0.6, 0.4]])
+    prob_d = jnp.array([[0.2, 0.8], [0.3, 0.7]])
+    prob_e = jnp.array([[0.5, 0.5], [0.1, 0.9]])
+    a = pyro.sample("a", dist.Categorical(jnp.array([0.3, 0.7])))
+    with pyro.plate("data", len(data)):
+        b = pyro.sample("b", dist.Categorical(prob_b[a]))
+        c = pyro.sample("c", dist.Categorical(prob_c[b]))
+        pyro.sample("d", dist.Categorical(prob_d[b]))
+        pyro.sample("e", dist.Categorical(prob_e[c]), obs=data)
+
+
+def guide_2(data, params):
+    probs_a = pyro.param("probs_a", params["probs_a"], constraint=constraints.simplex)
+    probs_b = pyro.param("probs_b", params["probs_b"], constraint=constraints.simplex)
+    probs_c = pyro.param("probs_c", params["probs_c"], constraint=constraints.simplex)
+    probs_d = pyro.param("probs_d", params["probs_d"], constraint=constraints.simplex)
+    a = pyro.sample("a", dist.Categorical(probs_a))
+    with pyro.plate("data", len(data)) as idx:
+        b = pyro.sample("b", dist.Categorical(probs_b[a]))
+        pyro.sample("c", dist.Categorical(Vindex(probs_c)[b, idx]))
+        pyro.sample("d", dist.Categorical(Vindex(probs_d)[b, idx]))
+
+
+params_2 = {
+    "probs_a": jnp.array([0.5, 0.5]),
+    "probs_b": jnp.array([[0.4, 0.6], [0.3, 0.7]]),
+    "probs_c": jnp.array([[[0.3, 0.7], [0.8, 0.2]], [[0.2, 0.8], [0.5, 0.5]]]),
+    "probs_d": jnp.array([[[0.2, 0.8], [0.9, 0.1]], [[0.1, 0.9], [0.4, 0.6]]]),
+}
+
+
+@pytest.mark.parametrize(
+    "model,guide,params,data",
+    [
+        (model_0, guide_0, params_0, jnp.array([-0.5, 2.0])),
+        (model_1, guide_1, params_1, jnp.array([-0.5, 2.0])),
+        (model_2, guide_2, params_2, jnp.array([0, 1])),
+    ],
+)
+def test_gradient(model, guide, params, data):
+    transform = dist.biject_to(dist.constraints.simplex)
+    params_raw = jax.tree_util.tree_map(transform.inv, params)
+
+    # Expected grads based on exact integration
+    elbo = infer.TraceEnum_ELBO(
+        max_plate_nesting=1,  # set this to ensure rng agrees across runs
+    )
+
+    def expected_loss_fn(params_raw):
+        params = jax.tree_util.tree_map(transform, params_raw)
+        return elbo.loss(
+            random.PRNGKey(0), {}, model, config_enumerate(guide), data, params
+        )
+
+    expected_loss, expected_grads = jax.value_and_grad(expected_loss_fn)(params_raw)
+
+    # Actual grads averaged over num_particles
+    elbo = infer.TraceGraph_ELBO(
+        num_particles=10_000,
+    )
+
+    def actual_loss_fn(params_raw):
+        params = jax.tree_util.tree_map(transform, params_raw)
+        return elbo.loss(random.PRNGKey(0), {}, model, guide, data, params)
+
+    actual_loss, actual_grads = jax.value_and_grad(actual_loss_fn)(params_raw)
+
+    for name in sorted(params):
+        logger.info("expected {} = {}".format(name, expected_grads[name]))
+        logger.info("actual {} = {}".format(name, actual_grads[name]))
+
+    assert_equal(actual_grads, expected_grads, prec=0.02)
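
For context, here is a minimal end-to-end sketch of how the TraceEnum_ELBO added by this patch is meant to be used with SVI. The mixture model, its parameter values, and the step count below are hypothetical illustrations, not taken from the patch; only the public API exercised above (config_enumerate, SVI, TraceEnum_ELBO) is assumed:

    import jax.numpy as jnp
    from jax import random

    import numpyro
    import numpyro.distributions as dist
    from numpyro import optim
    from numpyro.contrib.funsor import config_enumerate
    from numpyro.infer import SVI, TraceEnum_ELBO


    @config_enumerate
    def model(data):
        # Discrete mixture assignment; summed out exactly by TraceEnum_ELBO.
        mix_probs = jnp.array([0.3, 0.7])
        locs = jnp.array([-1.0, 1.0])
        with numpyro.plate("data", data.shape[0]):
            z = numpyro.sample("z", dist.Categorical(mix_probs))
            numpyro.sample("x", dist.Normal(locs[z], 1.0), obs=data)


    def guide(data):
        pass  # the only latent site is enumerated, so the guide is empty


    data = jnp.array([-0.9, -1.2, 0.8, 1.1])
    svi = SVI(model, guide, optim.Adam(0.01), TraceEnum_ELBO(max_plate_nesting=1))
    svi_result = svi.run(random.PRNGKey(0), 200, data)

Because every discrete site is enumerated in the model, the guide can be empty, exactly as in the tests above; the loss then reduces to the exact negative marginal log-likelihood of the observed data.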