Skip to content

Commit

Permalink
Set default updates for all graph RandomVariables in compile_pymc
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Feb 2, 2022
1 parent 4fedb60 commit 42229ec
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 6 deletions.
9 changes: 3 additions & 6 deletions pymc/aesaraf.py
Original file line number Diff line number Diff line change
Expand Up @@ -961,17 +961,14 @@ def compile_pymc(inputs, outputs, mode=None, **kwargs):
this function is called within a model context and the model `check_bounds` flag
is set to False.
"""

# Avoid circular dependency
from pymc.distributions import NoDistribution

# Set the default update of a NoDistribution RNG so that it is automatically
# Set the default update of RandomVariable's RNG so that it is automatically
# updated after every function call
# TODO: This won't work for variables with InnerGraphs (Scan and OpFromGraph)
output_to_list = outputs if isinstance(outputs, (list, tuple)) else [outputs]
for rv in (
node
for node in walk_model(output_to_list, walk_past_rvs=True)
if node.owner and isinstance(node.owner.op, NoDistribution)
if node.owner and isinstance(node.owner.op, RandomVariable)
):
rng = rv.owner.inputs[0]
if not hasattr(rng, "default_update"):
Expand Down
8 changes: 8 additions & 0 deletions pymc/tests/test_aesaraf.py
Original file line number Diff line number Diff line change
Expand Up @@ -574,3 +574,11 @@ def test_check_bounds_flag():
m.check_bounds = True
with m:
assert np.all(compile_pymc([], bound)() == -np.inf)


def test_compile_pymc_sets_default_updates():
rng = aesara.shared(np.random.default_rng(0))
x = pm.Normal.dist(rng=rng)
assert x.owner.inputs[0] is rng
f = compile_pymc([], x)
assert not np.isclose(f(), f())

0 comments on commit 42229ec

Please sign in to comment.