Keep broadcasting information in make_shared_replacements
It seems that broadcasting information is lost when applying
`pm.make_shared_replacements`, leading to problems with the Metropolis
sampler. Potentially related issues below:
 - pymc-devs#1083
 - pymc-devs#1304
 - pymc-devs#1983

This fix was previously suggested in the following issue:
 - pymc-devs#3337

It could be that further adaptations are necessary as indicated in the
issue. Strangely, this does not seem to lead to problems when using
NUTS.
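
The loss described above can be illustrated with a small pure-Python model of the bookkeeping. This is only a sketch and does not call Theano or PyMC3: the helper names `pattern_from_shape` and `shared_pattern` are hypothetical, and it assumes that a model variable marks each length-1 axis as broadcastable while a shared tensor built from an array gets a pattern with no broadcastable axes unless `broadcastable` is passed explicitly.

```python
from typing import Optional, Tuple

Pattern = Tuple[bool, ...]


def pattern_from_shape(shape: Tuple[int, ...]) -> Pattern:
    """Hypothetical helper: mark every length-1 axis as broadcastable."""
    return tuple(dim == 1 for dim in shape)


def shared_pattern(shape: Tuple[int, ...], broadcastable: Optional[Pattern] = None) -> Pattern:
    """Hypothetical stand-in for the pattern a shared variable ends up with.

    Assumption: without an explicit pattern, no axis is broadcastable.
    """
    if broadcastable is None:
        return (False,) * len(shape)
    return tuple(broadcastable)


shape = (1, 10)
original = pattern_from_shape(shape)         # pattern of the model variable
before_fix = shared_pattern(shape)           # pattern is not forwarded
after_fix = shared_pattern(shape, original)  # pattern is forwarded, as in this commit

print(original == before_fix)  # False
print(original == after_fix)   # True
```

Under these assumptions, the replacement only matches the type of the variable it stands in for when the original pattern is passed along.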
ExpectationMax authored and michaelosthege committed Mar 10, 2021
1 parent 9218f33 commit 79be613
Showing 3 changed files with 33 additions and 3 deletions.
4 changes: 2 additions & 2 deletions RELEASE-NOTES.md
@@ -9,8 +9,8 @@
 + ...

 ### Maintenance
-- ⚠ Our memoization mechanism wasn't robust against hash collisions (#4506), sometimes resulting in incorrect values in, for example, posterior predictives. The `pymc3.memoize` module was removed and replaced with `cachetools`. The `hashable` function and `WithMemoization` class were moved to `pymc3.util`.
-- ...
+- ⚠ Our memoization mechanism wasn't robust against hash collisions (#4506), sometimes resulting in incorrect values in, for example, posterior predictives. The `pymc3.memoize` module was removed and replaced with `cachetools`. The `hashable` function and `WithMemoization` class were moved to `pymc3.util` (see #4525).
+- `pm.make_shared_replacements` now retains broadcasting information which fixes issues with Metropolis samplers (see [#4492](https://github.com/pymc-devs/pymc3/pull/4492)).

 ## PyMC3 3.11.1 (12 February 2021)

25 changes: 25 additions & 0 deletions pymc3/tests/test_theanof.py
@@ -19,13 +19,38 @@
 import theano
 import theano.tensor as tt

+import pymc3 as pm
+
 from pymc3.theanof import _conversion_map, take_along_axis
 from pymc3.vartypes import int_types

 FLOATX = str(theano.config.floatX)
 INTX = str(_conversion_map[FLOATX])


+class TestBroadcasting:
+    def test_make_shared_replacements(self):
+        """Check if pm.make_shared_replacements preserves broadcasting."""
+
+        with pm.Model() as test_model:
+            test1 = pm.Normal("test1", mu=0.0, sigma=1.0, shape=(1, 10))
+            test2 = pm.Normal("test2", mu=0.0, sigma=1.0, shape=(10, 1))
+
+        # Replace test1 with a shared variable, keep test2 the same
+        replacement = pm.make_shared_replacements([test_model.test2], test_model)
+        assert test_model.test1.broadcastable == replacement[test_model.test1].broadcastable
+
+    def test_metropolis_sampling(self):
+        """Check if the Metropolis sampler can handle broadcasting."""
+        with pm.Model() as test_model:
+            test1 = pm.Normal("test1", mu=0.0, sigma=1.0, shape=(1, 10))
+            test2 = pm.Normal("test2", mu=test1, sigma=1.0, shape=(10, 10))
+
+            step = pm.Metropolis()
+            # This should fail immediately if broadcasting does not work.
+            pm.sample(tune=5, draws=7, cores=1, step=step, compute_convergence_checks=False)
+
+
 def _make_along_axis_idx(arr_shape, indices, axis):
     # compute dimensions to iterate over
     if str(indices.dtype) not in int_types:
7 changes: 6 additions & 1 deletion pymc3/theanof.py
@@ -235,7 +235,12 @@ def make_shared_replacements(vars, model):
     Dict of variable -> new shared variable
     """
     othervars = set(model.vars) - set(vars)
-    return {var: theano.shared(var.tag.test_value, var.name + "_shared") for var in othervars}
+    return {
+        var: theano.shared(
+            var.tag.test_value, var.name + "_shared", broadcastable=var.broadcastable
+        )
+        for var in othervars
+    }


 def join_nonshared_inputs(xs, vars, shared, make_shared=False):