Support RandomState Variable types in Scan #738

Open
Tracked by #1426
brandonwillard opened this issue Jan 10, 2022 · 4 comments
Labels:
- enhancement (New feature or request)
- help wanted (Extra attention is needed)
- important
- random variables (Involves random variables and/or sampling)
- refactor (This issue involves refactoring)
- request discussion
- Scan (Involves the `Scan` `Op`)

Comments

@brandonwillard
Member

brandonwillard commented Jan 10, 2022

Currently, it's not possible to manually manage the RNG state evolution of RandomVariables in Scan, because Scan supports only a limited set of Variable types.

Fixing this would take us a large step toward removing the need for shared variables and updates when combining Scans and RandomVariables.

For example,

```python
import aesara
import aesara.tensor as at

from aesara.tensor.random.type import random_generator_type


def inner_fn(prev_rng):
    res = at.random.normal(rng=prev_rng)
    # A RandomVariable node's first output is the updated RNG state
    new_rng = res.owner.outputs[0]
    return new_rng


initial_rng = random_generator_type()

out, _ = aesara.scan(
    inner_fn, outputs_info={"initial": initial_rng, "taps": [-1]}, n_steps=10
)
# TypeError: Tensor type field must be a TensorType; found <class 'aesara.tensor.random.type.RandomGeneratorType'>.
```

Since it won't be easy to return a collection of all intermediate RNG state variables, we could instead return just the last one (if any). Design decisions like this need to be settled before this can be implemented.
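To make the "return just the last one" option concrete, here is a toy sketch in plain NumPy rather than Aesara. The names `normal_step` and `scan_last_rng` are hypothetical, and copying the `Generator` is an assumption standing in for a non-inplace RNG update. Each step takes an RNG state and returns a new state plus a sample, mirroring how a RandomVariable node outputs `(new_rng, value)`; the loop stacks the samples but keeps only the final RNG state.

```python
import copy

import numpy as np


def normal_step(rng):
    """Draw one standard-normal sample without mutating the input RNG.

    Returns (new_rng, value), mirroring the (rng, value) outputs of a
    RandomVariable node.
    """
    new_rng = copy.deepcopy(rng)  # leave the caller's RNG state untouched
    value = new_rng.normal()
    return new_rng, value


def scan_last_rng(step, initial_rng, n_steps):
    """Carry the RNG through n_steps; stack every sample, keep only the last RNG."""
    rng = initial_rng
    samples = []
    for _ in range(n_steps):
        rng, value = step(rng)
        samples.append(value)
    return np.array(samples), rng


initial_rng = np.random.default_rng(123)
samples, last_rng = scan_last_rng(normal_step, initial_rng, n_steps=10)
```

Only one RNG state comes out of the loop, so nothing has to represent a "sequence of RNGs" as an output type.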

It's also possible that we could add special handling that reduces the boilerplate. For example, we could make it possible to omit the explicit RNG return statements, so that the following is sufficient to produce distinct samples on each iteration:

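```python
import aesara
import aesara.tensor as at

from aesara.tensor.random.type import random_generator_type


def inner_fn(prev_rng):
    res = at.random.normal(rng=prev_rng)
    # No explicit RNG output: Scan would thread the updated RNG automatically
    return res


initial_rng = random_generator_type()

# Proposed behavior, not currently supported
out, _ = aesara.scan(inner_fn, non_sequences=[initial_rng], n_steps=10)
```
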
@brandonwillard brandonwillard added enhancement New feature or request help wanted Extra attention is needed important refactor This issue involves refactoring Scan Involves the `Scan` `Op` labels Jan 10, 2022
@brandonwillard brandonwillard added random variables Involves random variables and/or sampling request discussion labels Jan 10, 2022
@ricardoV94
Contributor

> It's also possible that we could add special handling that reduces the boilerplate. For example, we could make it possible to omit the explicit RNG return statements, so that the following is sufficient to produce distinct samples on each iteration:

How would that work behind the scenes?

I don't have a big issue with returning the next RNGs explicitly. We could also add a helper property to RandomVariables, like `.next_rng`, that returns the first output (the updated RNG) to make this more readable.
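A minimal sketch of that helper, written as a plain function rather than a property. The name `next_rng` is hypothetical, not an existing Aesara API. It relies only on the convention used in the first example: a RandomVariable's `owner` node lists the updated RNG as `outputs[0]` (the sample itself is `outputs[1]`).

```python
def next_rng(rv_out):
    """Return the updated RNG variable produced alongside a RandomVariable sample.

    Hypothetical helper: assumes `rv_out.owner` is a RandomVariable node whose
    first output is the new RNG state.
    """
    node = getattr(rv_out, "owner", None)
    if node is None:
        raise ValueError("expected the output of a RandomVariable, got a root variable")
    return node.outputs[0]


# Intended use inside a Scan inner function (Aesara):
#     res = at.random.normal(rng=prev_rng)
#     return next_rng(res)
```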

@brandonwillard
Member Author

brandonwillard commented Jan 11, 2022

> > It's also possible that we could add special handling that reduces the boilerplate. For example, we could make it possible to omit the explicit RNG return statements, so that the following is sufficient to produce distinct samples on each iteration:
>
> How would that work behind the scenes?

We would construct a Scan that's equivalent to the first example.
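A toy NumPy analogue of that rewrite, assuming "equivalent" means threading the RNG as recurrent state. The hypothetical `make_explicit` wrapper takes an inner function that only returns a sample (as in the second example), runs it on a copy of the incoming RNG, and also returns that advanced copy, so a loop can carry it like the explicit `outputs_info` version. This only illustrates the idea; it is not how Aesara's Scan would implement it.

```python
import copy

import numpy as np


def make_explicit(inner_fn):
    """Wrap an inner function that returns only a sample so it also returns the new RNG."""

    def step(prev_rng):
        rng = copy.deepcopy(prev_rng)  # don't mutate the caller's RNG
        value = inner_fn(rng)  # inner_fn advances `rng` as a side effect
        return rng, value

    return step


def implicit_inner_fn(rng):
    # Like the second example: no explicit RNG return statement
    return rng.normal()


def run(step, initial_rng, n_steps):
    """Carry the RNG returned by each step into the next one."""
    rng, samples = initial_rng, []
    for _ in range(n_steps):
        rng, value = step(rng)
        samples.append(value)
    return samples, rng


samples, final_rng = run(make_explicit(implicit_inner_fn), np.random.default_rng(0), 10)
```

In this toy, feeding each step a fresh copy of the same initial RNG instead would repeat the first sample on every iteration.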

@ricardoV94
Contributor

ricardoV94 commented Feb 5, 2022

So the first step here is to allow Scan to accept RandomGeneratorType variables as inputs?

@brandonwillard
Member Author

brandonwillard commented Feb 5, 2022

> So the first step here is to allow Scan to accept RandomGeneratorType variables as inputs?

Its base class, yes.
