From 4a4e1a97c8039d592d3eb7785df33f3e481b04ea Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Wed, 19 Jun 2024 10:48:13 +0200 Subject: [PATCH 1/7] filter oout tests waiting for next tfp release --- test/contrib/test_tfp.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/contrib/test_tfp.py b/test/contrib/test_tfp.py index ab3adf64c..7670dbd5b 100644 --- a/test/contrib/test_tfp.py +++ b/test/contrib/test_tfp.py @@ -35,6 +35,7 @@ def f(x): assert res.scale == 1 +@pytest.mark.skip(reason="Waiting for the next tfp release") @pytest.mark.filterwarnings("ignore:can't resolve package") def test_transformed_distributions(): from tensorflow_probability.substrates.jax import ( @@ -113,6 +114,7 @@ def make_kernel_fn(target_log_prob_fn): ) +@pytest.mark.skip(reason="Waiting for the next tfp release") @pytest.mark.parametrize( "kernel, kwargs", [ @@ -243,6 +245,7 @@ def test_sample_tfp_distributions(): # test that sampling from unwrapped tensorflow_probability distributions works as # expected using numpyro.sample primitive +@pytest.mark.skip(reason="Waiting for the next tfp release") @pytest.mark.parametrize( "dist,args", [ @@ -270,6 +273,7 @@ def test_sample_unwrapped_tfp_distributions(dist, args): # test mixture distributions +@pytest.mark.skip(reason="Waiting for the next tfp release") def test_sample_unwrapped_mixture_same_family(): from tensorflow_probability.substrates.jax import distributions as tfd From cbed1f53b54ba0eff9d179c2bce938b95a046ae1 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Sun, 30 Jun 2024 22:22:35 +0200 Subject: [PATCH 2/7] fix issue 1446 --- numpyro/infer/util.py | 2 +- test/contrib/test_module.py | 17 ++++++++++++++--- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/numpyro/infer/util.py b/numpyro/infer/util.py index 3775893a7..1f649bc98 100644 --- a/numpyro/infer/util.py +++ b/numpyro/infer/util.py @@ -660,7 +660,7 @@ def initialize_model( data={ k: site["value"] for k, site in model_trace.items() - if site["type"] in ["param"] + if site["type"] in ["param", "mutable"] }, ) constrained_values = { diff --git a/test/contrib/test_module.py b/test/contrib/test_module.py index d3f27f17e..5c88c85d5 100644 --- a/test/contrib/test_module.py +++ b/test/contrib/test_module.py @@ -21,7 +21,8 @@ random_haiku_module, ) import numpyro.distributions as dist -from numpyro.infer import MCMC, NUTS +from numpyro.infer import MCMC, NUTS, SVI, Trace_ELBO +from numpyro.infer.autoguide import AutoDelta pytestmark = pytest.mark.filterwarnings( "ignore:jax.tree_.+ is deprecated:FutureWarning" @@ -242,7 +243,7 @@ def model(): nn = haiku_module("nn", transform(fn), apply_rng=dropout, input_shape=(4, 3)) x = numpyro.sample("x", dist.Normal(0, 1).expand([4, 3]).to_event(2)) if dropout: - y = nn(numpyro.prng_key(), x) + y = nn(random.PRNGKey(0), x) else: y = nn(x) numpyro.deterministic("y", y) @@ -256,6 +257,11 @@ def model(): else: assert set(tr.keys()) == {"nn$params", "x", "y"} + # test svi + guide = AutoDelta(model) + svi = SVI(model, guide, numpyro.optim.Adam(0.01), Trace_ELBO()) + svi.run(random.PRNGKey(100), 10) + @pytest.mark.parametrize("dropout", [True, False]) @pytest.mark.parametrize("batchnorm", [True, False]) @@ -287,7 +293,7 @@ def model(): ) x = numpyro.sample("x", dist.Normal(0, 1).expand([4, 3]).to_event(2)) if dropout: - y = net(x, rngs={"dropout": numpyro.prng_key()}) + y = net(x, rngs={"dropout": random.PRNGKey(0)}) else: y = net(x) numpyro.deterministic("y", y) @@ -300,3 +306,8 @@ def model(): assert tr["nn$state"]["type"] == "mutable" else: assert set(tr.keys()) == {"nn$params", "x", "y"} + + # test svi + guide = AutoDelta(model) + svi = SVI(model, guide, numpyro.optim.Adam(0.01), Trace_ELBO()) + svi.run(random.PRNGKey(100), 10) From a37f7cea9d252ea2cc362a3a9346ab7f90de597a Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Mon, 1 Jul 2024 10:03:36 +0200 Subject: [PATCH 3/7] add feddback (not working) --- numpyro/handlers.py | 3 ++- numpyro/infer/util.py | 26 ++++++++++---------------- test/contrib/test_module.py | 4 ++-- 3 files changed, 14 insertions(+), 19 deletions(-) diff --git a/numpyro/handlers.py b/numpyro/handlers.py index 6e13aeb70..ba710b59b 100644 --- a/numpyro/handlers.py +++ b/numpyro/handlers.py @@ -805,7 +805,8 @@ def __init__(self, fn=None, data=None, substitute_fn=None): def process_message(self, msg): if ( - msg["type"] not in ("sample", "param", "mutable", "plate", "deterministic") + msg["type"] + not in ("sample", "param", "mutable", "plate", "deterministic", "prng_key") ) or msg.get("_control_flow_done", False): if msg["type"] == "control_flow": if self.data is not None: diff --git a/numpyro/infer/util.py b/numpyro/infer/util.py index 1f649bc98..96211f1f2 100644 --- a/numpyro/infer/util.py +++ b/numpyro/infer/util.py @@ -653,16 +653,17 @@ def initialize_model( has_enumerate_support, model_trace, ) = _get_model_transforms(substituted_model, model_args, model_kwargs) + # substitute param sites from model_trace to model so # we don't need to generate again parameters of `numpyro.module` - model = substitute( - model, - data={ - k: site["value"] - for k, site in model_trace.items() - if site["type"] in ["param", "mutable"] - }, - ) + def substitute_fn(site): + if site["type"] in ["param", "mutable"]: + return site["value"] + elif site["type"] == "prng_key": + return random.PRNGKey(0) + + model = substitute(model, substitute_fn=substitute_fn) + constrained_values = { k: v["value"] for k, v in model_trace.items() @@ -701,14 +702,7 @@ def initialize_model( prototype_params = transform_fn(inv_transforms, constrained_values, invert=True) (init_params, pe, grad), is_valid = find_valid_initial_params( rng_key, - substitute( - model, - data={ - k: site["value"] - for k, site in model_trace.items() - if site["type"] in ["plate"] - }, - ), + substitute(model, substitute_fn=substitute_fn), init_strategy=init_strategy, enum=has_enumerate_support, model_args=model_args, diff --git a/test/contrib/test_module.py b/test/contrib/test_module.py index 5c88c85d5..a1342507f 100644 --- a/test/contrib/test_module.py +++ b/test/contrib/test_module.py @@ -243,7 +243,7 @@ def model(): nn = haiku_module("nn", transform(fn), apply_rng=dropout, input_shape=(4, 3)) x = numpyro.sample("x", dist.Normal(0, 1).expand([4, 3]).to_event(2)) if dropout: - y = nn(random.PRNGKey(0), x) + y = nn(numpyro.prng_key(), x) else: y = nn(x) numpyro.deterministic("y", y) @@ -293,7 +293,7 @@ def model(): ) x = numpyro.sample("x", dist.Normal(0, 1).expand([4, 3]).to_event(2)) if dropout: - y = net(x, rngs={"dropout": random.PRNGKey(0)}) + y = net(x, rngs={"dropout": numpyro.prng_key()}) else: y = net(x) numpyro.deterministic("y", y) From 232c89c08954a19d5578198137f567f3e62081a4 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Mon, 1 Jul 2024 15:15:02 +0200 Subject: [PATCH 4/7] feedbackl 2 --- numpyro/handlers.py | 2 +- numpyro/infer/util.py | 26 ++++++++++++++++++++------ 2 files changed, 21 insertions(+), 7 deletions(-) diff --git a/numpyro/handlers.py b/numpyro/handlers.py index ba710b59b..2e9481c07 100644 --- a/numpyro/handlers.py +++ b/numpyro/handlers.py @@ -818,7 +818,7 @@ def process_message(self, msg): return if self.data is not None: - value = self.data.get(msg["name"]) + value = self.data.get(msg.get("name")) else: value = self.substitute_fn(msg) diff --git a/numpyro/infer/util.py b/numpyro/infer/util.py index 96211f1f2..7e9301084 100644 --- a/numpyro/infer/util.py +++ b/numpyro/infer/util.py @@ -656,13 +656,20 @@ def initialize_model( # substitute param sites from model_trace to model so # we don't need to generate again parameters of `numpyro.module` - def substitute_fn(site): - if site["type"] in ["param", "mutable"]: - return site["value"] - elif site["type"] == "prng_key": + model = substitute( + model, + data={ + k: site["value"] + for k, site in model_trace.items() + if site["type"] in ["param", "mutable"] + }, + ) + + def substitute_key(msg): + if msg["type"] == "prng_key": return random.PRNGKey(0) - model = substitute(model, substitute_fn=substitute_fn) + model = substitute(model, substitute_fn=substitute_key) constrained_values = { k: v["value"] @@ -702,7 +709,14 @@ def substitute_fn(site): prototype_params = transform_fn(inv_transforms, constrained_values, invert=True) (init_params, pe, grad), is_valid = find_valid_initial_params( rng_key, - substitute(model, substitute_fn=substitute_fn), + substitute( + model, + data={ + k: site["value"] + for k, site in model_trace.items() + if site["type"] in ["plate"] + }, + ), init_strategy=init_strategy, enum=has_enumerate_support, model_args=model_args, From 9a6b2081ad694106ba1bcfa0ca869415a77d3f77 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Mon, 1 Jul 2024 16:37:15 +0200 Subject: [PATCH 5/7] default handler --- numpyro/infer/util.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/numpyro/infer/util.py b/numpyro/infer/util.py index 7e9301084..ae85d6b05 100644 --- a/numpyro/infer/util.py +++ b/numpyro/infer/util.py @@ -21,6 +21,7 @@ from numpyro.distributions.util import is_identically_one, sum_rightmost from numpyro.handlers import condition, replay, seed, substitute, trace from numpyro.infer.initialization import init_to_uniform, init_to_value +from numpyro.primitives import Messenger from numpyro.util import ( _validate_model, find_stack_level, @@ -653,7 +654,6 @@ def initialize_model( has_enumerate_support, model_trace, ) = _get_model_transforms(substituted_model, model_args, model_kwargs) - # substitute param sites from model_trace to model so # we don't need to generate again parameters of `numpyro.module` model = substitute( @@ -665,11 +665,12 @@ def initialize_model( }, ) - def substitute_key(msg): - if msg["type"] == "prng_key": - return random.PRNGKey(0) + class _substitute_default_key(Messenger): + def process_message(self, msg): + if msg["type"] == "prng_key" and msg["value"] is None: + msg["value"] = random.PRNGKey(0) - model = substitute(model, substitute_fn=substitute_key) + model = _substitute_default_key(model) constrained_values = { k: v["value"] From 65967d3095281892460b538b056c95588af1b7d9 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Mon, 1 Jul 2024 16:44:43 +0200 Subject: [PATCH 6/7] rm prng_key from substitute --- numpyro/handlers.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/numpyro/handlers.py b/numpyro/handlers.py index 2e9481c07..11fdde45c 100644 --- a/numpyro/handlers.py +++ b/numpyro/handlers.py @@ -805,8 +805,7 @@ def __init__(self, fn=None, data=None, substitute_fn=None): def process_message(self, msg): if ( - msg["type"] - not in ("sample", "param", "mutable", "plate", "deterministic", "prng_key") + msg["type"] not in ("sample", "param", "mutable", "plate", "deterministic") ) or msg.get("_control_flow_done", False): if msg["type"] == "control_flow": if self.data is not None: From 1625111576d4672298f60acaa2d76a4cf86e7d70 Mon Sep 17 00:00:00 2001 From: juanitorduz Date: Mon, 1 Jul 2024 17:29:32 +0200 Subject: [PATCH 7/7] remove class from function --- numpyro/infer/util.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/numpyro/infer/util.py b/numpyro/infer/util.py index ae85d6b05..56c4d2c4d 100644 --- a/numpyro/infer/util.py +++ b/numpyro/infer/util.py @@ -47,6 +47,12 @@ ParamInfo = namedtuple("ParamInfo", ["z", "potential_energy", "z_grad"]) +class _substitute_default_key(Messenger): + def process_message(self, msg): + if msg["type"] == "prng_key" and msg["value"] is None: + msg["value"] = random.PRNGKey(0) + + def log_density(model, model_args, model_kwargs, params): """ (EXPERIMENTAL INTERFACE) Computes log of joint density for the model given @@ -665,11 +671,6 @@ def initialize_model( }, ) - class _substitute_default_key(Messenger): - def process_message(self, msg): - if msg["type"] == "prng_key" and msg["value"] is None: - msg["value"] = random.PRNGKey(0) - model = _substitute_default_key(model) constrained_values = {