Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow to use NeuTra on models with plates #1826

Merged
merged 2 commits into from
Jul 2, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions numpyro/infer/reparam.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,9 +281,15 @@ def __call__(self, name, fn, obs):
compute_density = numpyro.get_mask() is not False
if not self._x_unconstrained: # On first sample site.
# Sample a shared latent.
model_plates = {
msg["name"]
for msg in self.guide.prototype_trace.values()
if msg["type"] == "plate"
}
z_unconstrained = numpyro.sample(
"{}_shared_latent".format(self.guide.prefix),
self.guide.get_base_dist().mask(False),
infer={"block_plates": model_plates},
)

# Differentiably transform.
Expand Down
6 changes: 6 additions & 0 deletions numpyro/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,12 @@ def process_message(self, msg):
)
return

if (
"block_plates" in msg.get("infer", {})
and self.name in msg["infer"]["block_plates"]
):
return

cond_indep_stack = msg["cond_indep_stack"]
frame = CondIndepStackFrame(self.name, self.dim, self.subsample_size)
cond_indep_stack.append(frame)
Expand Down
18 changes: 17 additions & 1 deletion test/infer/test_reparam.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from numpyro.distributions.transforms import AffineTransform, ExpTransform
import numpyro.handlers as handlers
from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO
from numpyro.infer.autoguide import AutoIAFNormal
from numpyro.infer.autoguide import AutoDiagonalNormal, AutoIAFNormal
from numpyro.infer.reparam import (
CircularReparam,
ExplicitReparam,
Expand Down Expand Up @@ -228,6 +228,22 @@ def test_neutra_reparam_unobserved_model():
reparam_model(data=None)


def test_neutra_reparam_with_plate():
def model():
with numpyro.plate("N", 3, dim=-1):
x = numpyro.sample("x", dist.Normal(0, 1))
assert x.shape == (3,)

guide = AutoDiagonalNormal(model)
svi = SVI(model, guide, Adam(1e-3), Trace_ELBO())
svi_state = svi.init(random.PRNGKey(0))
params = svi.get_params(svi_state)
neutra = NeuTraReparam(guide, params)
reparam_model = neutra.reparam(model)
with handlers.seed(rng_seed=0):
reparam_model()


@pytest.mark.parametrize("shape", [(), (4,), (3, 2)], ids=str)
@pytest.mark.parametrize("centered", [0.0, 0.6, 1.0, None])
@pytest.mark.parametrize("dist_type", ["Normal", "StudentT"])
Expand Down
Loading