Skip to content

Commit

Permalink
Avoid unclear TypeError when using theano.shared variables as input t…
Browse files Browse the repository at this point in the history
…o distribution parameters (#4445)

* Added default testvalue support for theano.shared

Co-authored-by: Ricardo <[email protected]>
  • Loading branch information
kc611 and ricardoV94 authored Jan 31, 2021
1 parent 2d3ec8f commit 07679ec
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 7 deletions.
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
- We upgraded to `Theano-PyMC v1.1.2` which [includes bugfixes](https://github.com/pymc-devs/aesara/compare/rel-1.1.0...rel-1.1.2) for warning floods and compiledir locking (see [#4444](https://github.com/pymc-devs/pymc3/pull/4444))
- `Theano-PyMC v1.1.2` also fixed an important issue in `tt.switch` that affected the behavior of several PyMC distributions, including at least the `Bernoulli` and `TruncatedNormal` (see[#4448](https://github.com/pymc-devs/pymc3/pull/4448))
- `math.log1mexp_numpy` no longer raises RuntimeWarning when given very small inputs. These were commonly observed during NUTS sampling (see [#4428](https://github.com/pymc-devs/pymc3/pull/4428)).
- `ScalarSharedVariable` can now be used as an input to other RVs directly (see [#4445](https://github.com/pymc-devs/pymc3/pull/4445)).

## PyMC3 3.11.0 (21 January 2021)

Expand Down
14 changes: 7 additions & 7 deletions pymc3/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,17 +148,17 @@ def default(self):
def get_test_val(self, val, defaults):
if val is None:
for v in defaults:
if hasattr(self, v) and np.all(np.isfinite(self.getattr_value(v))):
return self.getattr_value(v)
else:
return self.getattr_value(val)

if val is None:
if hasattr(self, v):
attr_val = self.getattr_value(v)
if np.all(np.isfinite(attr_val)):
return attr_val
raise AttributeError(
"%s has no finite default value to use, "
"checked: %s. Pass testval argument or "
"adjust so value is finite." % (self, str(defaults))
)
else:
return self.getattr_value(val)

def getattr_value(self, val):
if isinstance(val, string_types):
Expand All @@ -167,7 +167,7 @@ def getattr_value(self, val):
if isinstance(val, tt.TensorVariable):
return val.tag.test_value

if isinstance(val, tt.sharedvar.TensorSharedVariable):
if isinstance(val, tt.sharedvar.SharedVariable):
return val.get_value()

if isinstance(val, theano_constant):
Expand Down
22 changes: 22 additions & 0 deletions pymc3/tests/test_data_container.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
import pandas as pd
import pytest

from theano import shared

import pymc3 as pm

from pymc3.tests.helpers import SeededTest
Expand Down Expand Up @@ -156,6 +158,26 @@ def test_shared_data_as_rv_input(self):
np.testing.assert_allclose(np.array([2.0, 4.0, 6.0]), x.get_value(), atol=1e-1)
np.testing.assert_allclose(np.array([2.0, 4.0, 6.0]), trace["y"].mean(0), atol=1e-1)

def test_shared_scalar_as_rv_input(self):
# See https://github.com/pymc-devs/pymc3/issues/3139
with pm.Model() as m:
shared_var = shared(5.0)
v = pm.Normal("v", mu=shared_var, shape=1)

np.testing.assert_allclose(
v.logp({"v": [5.0]}),
-0.91893853,
rtol=1e-5,
)

shared_var.set_value(10.0)

np.testing.assert_allclose(
v.logp({"v": [10.0]}),
-0.91893853,
rtol=1e-5,
)

def test_creation_of_data_outside_model_context(self):
with pytest.raises((IndexError, TypeError)) as error:
pm.Data("data", [1.1, 2.2, 3.3])
Expand Down

0 comments on commit 07679ec

Please sign in to comment.