From 42229ece13a313d1f2a133519736106d0718fce8 Mon Sep 17 00:00:00 2001 From: Ricardo Date: Wed, 2 Feb 2022 19:09:25 +0100 Subject: [PATCH] Set default updates for all graph RandomVariables in compile_pymc --- pymc/aesaraf.py | 9 +++------ pymc/tests/test_aesaraf.py | 8 ++++++++ 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/pymc/aesaraf.py b/pymc/aesaraf.py index 7c3d9350c8..8616f6ded6 100644 --- a/pymc/aesaraf.py +++ b/pymc/aesaraf.py @@ -961,17 +961,14 @@ def compile_pymc(inputs, outputs, mode=None, **kwargs): this function is called within a model context and the model `check_bounds` flag is set to False. """ - - # Avoid circular dependency - from pymc.distributions import NoDistribution - - # Set the default update of a NoDistribution RNG so that it is automatically + # Set the default update of RandomVariable's RNG so that it is automatically # updated after every function call + # TODO: This won't work for variables with InnerGraphs (Scan and OpFromGraph) output_to_list = outputs if isinstance(outputs, (list, tuple)) else [outputs] for rv in ( node for node in walk_model(output_to_list, walk_past_rvs=True) - if node.owner and isinstance(node.owner.op, NoDistribution) + if node.owner and isinstance(node.owner.op, RandomVariable) ): rng = rv.owner.inputs[0] if not hasattr(rng, "default_update"): diff --git a/pymc/tests/test_aesaraf.py b/pymc/tests/test_aesaraf.py index 39a6ba4cb9..6f714305b8 100644 --- a/pymc/tests/test_aesaraf.py +++ b/pymc/tests/test_aesaraf.py @@ -574,3 +574,11 @@ def test_check_bounds_flag(): m.check_bounds = True with m: assert np.all(compile_pymc([], bound)() == -np.inf) + + +def test_compile_pymc_sets_default_updates(): + rng = aesara.shared(np.random.default_rng(0)) + x = pm.Normal.dist(rng=rng) + assert x.owner.inputs[0] is rng + f = compile_pymc([], x) + assert not np.isclose(f(), f())