From 92c6274e9f1ac81d2000f89034a02f48a22abd20 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Mon, 1 Jul 2024 00:08:23 -0400 Subject: [PATCH 1/2] allow to use NeuTra with plate --- numpyro/infer/reparam.py | 6 ++++++ numpyro/primitives.py | 6 ++++++ test/infer/test_reparam.py | 18 +++++++++++++++++- 3 files changed, 29 insertions(+), 1 deletion(-) diff --git a/numpyro/infer/reparam.py b/numpyro/infer/reparam.py index 9fb85d6d0..a5e482a02 100644 --- a/numpyro/infer/reparam.py +++ b/numpyro/infer/reparam.py @@ -281,9 +281,15 @@ def __call__(self, name, fn, obs): compute_density = numpyro.get_mask() is not False if not self._x_unconstrained: # On first sample site. # Sample a shared latent. + model_plates = { + msg["name"] + for msg in self.guide.prototype_trace.values() + if msg["type"] == "plate" + } z_unconstrained = numpyro.sample( "{}_shared_latent".format(self.guide.prefix), self.guide.get_base_dist().mask(False), + infer={"block_plates": model_plates}, ) # Differentiably transform. diff --git a/numpyro/primitives.py b/numpyro/primitives.py index ac02a8856..cf84b9512 100644 --- a/numpyro/primitives.py +++ b/numpyro/primitives.py @@ -530,6 +530,12 @@ def process_message(self, msg): ) return + if ( + "block_plates" in msg.get("infer", {}) + and self.name in msg["infer"]["block_plates"] + ): + return + cond_indep_stack = msg["cond_indep_stack"] frame = CondIndepStackFrame(self.name, self.dim, self.subsample_size) cond_indep_stack.append(frame) diff --git a/test/infer/test_reparam.py b/test/infer/test_reparam.py index 456ff1659..f1d72f742 100644 --- a/test/infer/test_reparam.py +++ b/test/infer/test_reparam.py @@ -15,7 +15,7 @@ from numpyro.distributions.transforms import AffineTransform, ExpTransform import numpyro.handlers as handlers from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO -from numpyro.infer.autoguide import AutoIAFNormal +from numpyro.infer.autoguide import AutoDiagonalNormal, AutoIAFNormal from numpyro.infer.reparam import ( CircularReparam, ExplicitReparam, @@ -228,6 +228,22 @@ def test_neutra_reparam_unobserved_model(): reparam_model(data=None) +def test_neutra_reparam_with_plate(): + def model(): + with numpyro.plate("N", 3, dim=-1): + x = numpyro.sample("x", dist.Normal(0, 1)) + assert x.shape == (3,) + + guide = AutoDiagonalNormal(model) + svi = SVI(model, guide, Adam(1e-3), Trace_ELBO()) + svi_state = svi.init(random.PRNGKey(0)) + params = svi.get_params(svi_state) + neutra = NeuTraReparam(guide, params) + reparam_model = neutra.reparam(model) + with handlers.seed(rng_seed=0): + reparam_model() + + @pytest.mark.parametrize("shape", [(), (4,), (3, 2)], ids=str) @pytest.mark.parametrize("centered", [0.0, 0.6, 1.0, None]) @pytest.mark.parametrize("dist_type", ["Normal", "StudentT"]) From ca732fede4a34c07e516c7cb3a86115c26e4c3c4 Mon Sep 17 00:00:00 2001 From: Du Phan Date: Tue, 2 Jul 2024 06:52:22 -0400 Subject: [PATCH 2/2] Fix typo in reparam.py --- numpyro/infer/reparam.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/numpyro/infer/reparam.py b/numpyro/infer/reparam.py index a5e482a02..60be9fa7a 100644 --- a/numpyro/infer/reparam.py +++ b/numpyro/infer/reparam.py @@ -226,7 +226,7 @@ class NeuTraReparam(Reparam): # Step 2. Use trained guide in NeuTra MCMC neutra = NeuTraReparam(guide) - model = netra.reparam(model) + model = neutra.reparam(model) nuts = NUTS(model) # ...now use the model in HMC or NUTS...