Change RandomVariable-in-Scan semantics #898

Open
brandonwillard opened this issue Apr 11, 2022 · 2 comments
Labels
enhancement New feature or request help wanted Extra attention is needed important random variables Involves random variables and/or sampling Scan Involves the `Scan` `Op`

Comments

@brandonwillard
Member

brandonwillard commented Apr 11, 2022

I'm proposing that every RandomVariable constructed within the body of a Scan draw a new sample on each iteration of that Scan.

Background

The current behavior of a naive loop with a random variable is demonstrated in the following graph:

import aesara
import aesara.tensor as at


def inner_fn():
    return at.random.normal()


out, updates = aesara.scan(inner_fn, n_steps=10)

fn = aesara.function([], out)
fn()
# array([-0.50043598, -0.50043598, -0.50043598, -0.50043598, -0.50043598,
#        -0.50043598, -0.50043598, -0.50043598, -0.50043598, -0.50043598])

The repeated output is the single sample value of at.random.normal(), which is equivalent to the following, more explicit graph:

X = at.random.normal()


def inner_fn(x):
    return x


out, updates = aesara.scan(inner_fn, non_sequences=[X], n_steps=10)

fn = aesara.function([], out)
fn()
# array([1.51209684, 1.51209684, 1.51209684, 1.51209684, 1.51209684,
#        1.51209684, 1.51209684, 1.51209684, 1.51209684, 1.51209684])

My contention is that the former graph should always produce a distinct sample for each iteration performed by the Scan Op. If one desires a repeated value, as is currently produced, then constructing a graph like the latter is the appropriate approach.

In other words, an expression like at.random.normal() always indicates a distinct sample within the body of a Scan, just as it would in a plain Python loop.
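The plain-Python analogy can be sketched as follows. This is only an illustration: it uses the standard library's `random.Random` as a stand-in for the RNG behind `at.random.normal()`, and none of the names below are Aesara API.

```python
import random

# Stand-in RNG; the seed is arbitrary and only makes the sketch repeatable.
rng = random.Random(2022)

# Sample constructed inside the loop body: one fresh draw per iteration,
# which is the behavior proposed for RandomVariables inside a Scan.
per_iteration = [rng.normalvariate(0.0, 1.0) for _ in range(10)]

# Sample constructed once, outside the loop, and passed in unchanged:
# the same value is repeated, as in the `non_sequences=[X]` graph.
x = rng.normalvariate(0.0, 1.0)
repeated = [x for _ in range(10)]
```

Under the proposed semantics, the first form corresponds to calling `at.random.normal()` inside `inner_fn`, and the second to passing a pre-built `X` through `non_sequences`.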

As of now, I don't see a reason to preserve the current semantics, which map the first example above to the second. This isn't an issue anywhere else, because no other Op depends so opaquely on a state object as RandomVariable does. To preserve the common semantics of random sampling, Scan should handle this case specifically and guarantee those semantics.

One Possible Solution

One quick way to accomplish this is to change the RandomVariable.inplace attribute of all RandomVariables in an inner-graph to True—regardless of whether or not their RandomType instances are shared.

This somewhat breaks the consistency of the RNG object/RandomType instance "evolution", which allows one to provide an RNG to a RandomVariable and get a new RNG corresponding to the updated RNG state after sampling. The input and output RNG are supposed to be distinct—unless the in-place optimization/rewrite is performed, in which case they are the same RNG state.
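As a rough illustration of the difference between the two modes, here is a sketch that again uses `random.Random` as a stand-in for a RandomType value. The helpers `draw_functional` and `draw_inplace` are hypothetical names for this example, not Aesara functions.

```python
import copy
import random


def draw_functional(rng):
    """Return (next_rng, sample) without modifying `rng`."""
    next_rng = copy.deepcopy(rng)
    sample = next_rng.normalvariate(0.0, 1.0)
    return next_rng, sample


def draw_inplace(rng):
    """Return (rng, sample); the input object itself is advanced."""
    sample = rng.normalvariate(0.0, 1.0)
    return rng, sample


rng = random.Random(898)
state_before = rng.getstate()

# Functional "evolution": the input RNG is untouched, the output is a new object.
next_rng, s1 = draw_functional(rng)
input_unchanged = rng.getstate() == state_before

# In-place: input and output are the same object, and its state has moved on.
same_rng, s2 = draw_inplace(rng)
input_advanced = rng.getstate() != state_before
```

Both calls start from the same state, so they produce the same sample; they differ only in whether the caller's RNG object is mutated.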

If we cloned the initial RNG states in a Scan inner-graph and performed the iterations in-place on the cloned states, then returned those, I believe that the standard RNG "evolution" semantics would be preserved. Even so, these input/output RNG states are essentially hidden from the user-level and almost never used explicitly (e.g. chaining input/output RNGs results in convoluted graphs and many cloned RNG states at runtime), so the utility of fully preserving these semantics is itself questionable.
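The clone-then-update idea can be sketched in plain Python, assuming `random.Random` as a stand-in for the RNG state and a hypothetical helper `loop_with_cloned_rng` in place of a Scan. It is not how Scan is implemented, only a model of the intended semantics.

```python
import copy
import random


def loop_with_cloned_rng(rng, n_steps):
    """Clone `rng` once, draw in-place on the clone, and return the clone."""
    local_rng = copy.deepcopy(rng)  # the caller's RNG is never mutated
    samples = [local_rng.normalvariate(0.0, 1.0) for _ in range(n_steps)]
    return samples, local_rng


rng = random.Random(898)
state_before = rng.getstate()

samples, final_rng = loop_with_cloned_rng(rng, 5)

# From the outside, the loop looks functional: the input state is preserved
# and the updated state is available as a separate output.
outer_state_preserved = rng.getstate() == state_before
```

Seen from outside the loop, this matches the usual input/output RNG contract, even though every draw inside the loop mutates the cloned state.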

Also, Scans do not have a means of returning RNG states, so they've never been able to preserve these semantics faithfully (see #738).

The same thing could be accomplished without in-placing, for example by adding an extra tap that carries the RNG output by one iteration into the next. That has little to no foreseeable utility, aside from, say, the case in which one wants the RNG state at every iteration in the form of a Scan output. Regardless, that case would require #738 and could easily work alongside the proposed in-place changes (e.g. by copying the RNG state when/if it's an explicit output of the inner-graph).
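To make the "copy the state when it's an explicit output" idea concrete, here is a small sketch under the same assumptions as before: `random.Random` stands in for the RNG state, and `loop_recording_states` is a hypothetical helper, not a Scan feature.

```python
import random


def loop_recording_states(rng, n_steps):
    """Draw in-place and record a snapshot of the RNG state after each step."""
    samples, states = [], []
    for _ in range(n_steps):
        samples.append(rng.normalvariate(0.0, 1.0))  # advances rng in place
        states.append(rng.getstate())  # immutable tuple snapshot
    return samples, states


rng = random.Random(738)
samples, states = loop_recording_states(rng, 4)

# Restoring the snapshot taken after the second step reproduces the third draw.
replay = random.Random()
replay.setstate(states[1])
replayed = replay.normalvariate(0.0, 1.0)
```

The draws themselves stay in-place; a copy is only made for the states that are actually returned.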

Related Issues:

  • In-place the RandomVariables within a Scan #543
    • If RandomVariables are in-placed within a Scan loop body, they will be updated per-iteration. This would partially accomplish the goals set out here, but not in a consistent fashion, because the relevant rewrite is only applied in FAST_RUN mode and it applies to all RandomVariables in a graph, not just those in the inner-graphs of a Scan.
  • How should/do updates work in nested Scans? #542
    • This issue is particular to nested Scans and shared variable RNG-based updates, and the approach described in this issue obviates the need for shared variables and automatically accounts for nested Scans.
  • Fix some usability issues surrounding shared RNG objects #454
    • Again, this is another issue that centers on shared RNG states and is superseded by this issue, which provides a solution that doesn't depend on shared variables and the updates mechanism.
@brandonwillard brandonwillard added enhancement New feature or request help wanted Extra attention is needed important random variables Involves random variables and/or sampling Scan Involves the `Scan` `Op` labels Apr 11, 2022
@ricardoV94
Contributor

Is there any reason to make only the RVs in a Scan always in-place, and not all of them?

@brandonwillard
Member Author

brandonwillard commented Apr 12, 2022

All of the RandomVariables in a graph are made in-place (when allowed by DestroyHandler and the like) in FAST_RUN mode; otherwise, in-placing would break the functional (i.e. no side effects) expectations of Aesara graphs and, specifically, the RNG state evolutions performed by RandomVariables.

As I mentioned above, I don't think those are real concerns if in-place updates are used on a copy of the input RNG state(s) and only in the body of a Scan loop. If Scan returned an output RNG state (and it probably should), I believe the end result would be consistent enough with RandomVariable and its state evolutions.
