Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Set default updates for all graph RandomVariables in compile_pymc #5442

Merged
merged 1 commit into from
Feb 3, 2022

Conversation

ricardoV94
Copy link
Member

@ricardoV94 ricardoV94 commented Feb 2, 2022

This PR makes compile_pymc set default_updates for any RandomVariables found in the graph. Previously this logic was limited to NoDistribution to accommodate the special case of pm.Simulator.

With this change the following code now produces the "expected" behavior (i.e., different draws of x):

with pm.Model() as m:
  pm.Deterministic("x", pm.Normal.dist())
  prior = pm.sample_prior_predictive()

With this change we may also consider removing this responsibility from pm.Distribution which was doing this for registered (named) variables, which should facilitate #5308

rng = kwargs.pop("rng", None)
if (
rv_out.owner
and isinstance(rv_out.owner.op, RandomVariable)
and isinstance(rng, RandomStateSharedVariable)
and not getattr(rng, "default_update", None)
):
# This tells `aesara.function` that the shared RNG variable
# is mutable, which--in turn--tells the `FunctionGraph`
# `Supervisor` feature to allow in-place updates on the variable.
# Without it, the `RandomVariable`s could not be optimized to allow
# in-place RNG updates, forcing all sample results from compiled
# functions to be the same on repeated evaluations.
new_rng = rv_out.owner.outputs[0]
rv_out.update = (rng, new_rng)
rng.default_update = new_rng


This is likely to be all reevaluated soon, as Aesara is in the process of refining the API for seeding of RandomVariables

@codecov
Copy link

codecov bot commented Feb 2, 2022

Codecov Report

Merging #5442 (42229ec) into main (4fedb60) will decrease coverage by 0.02%.
The diff coverage is n/a.

Impacted file tree graph

@@            Coverage Diff             @@
##             main    #5442      +/-   ##
==========================================
- Coverage   81.41%   81.39%   -0.03%     
==========================================
  Files          82       82              
  Lines       14213    14212       -1     
==========================================
- Hits        11572    11568       -4     
- Misses       2641     2644       +3     
Impacted Files Coverage Δ
pymc/aesaraf.py 90.17% <ø> (-0.03%) ⬇️
pymc/parallel_sampling.py 86.71% <0.00%> (-1.00%) ⬇️

@ricardoV94 ricardoV94 merged commit 25c6772 into pymc-devs:main Feb 3, 2022
@ricardoV94 ricardoV94 deleted the set_default_updates_RVs branch June 6, 2023 03:05
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants