Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Test model logp before starting any MCMC chains #4211

Merged
merged 10 commits into from
Nov 27, 2020
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# Release Notes

## PyMC3 3.10.0

### Maintenance
- Test model logp before starting any MCMC chains (see [#4116](https://github.com/pymc-devs/pymc3/issues/4116))
- Fix bug in `model.check_test_point` that caused the `test_point` argument to be ignored. (see [PR #4211](https://github.com/pymc-devs/pymc3/pull/4211#issuecomment-727142721))

## PyMC3 3.9.x (on deck)

### Maintenance
Expand Down
2 changes: 1 addition & 1 deletion pymc3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1365,7 +1365,7 @@ def check_test_point(self, test_point=None, round_vals=2):
test_point = self.test_point

return Series(
{RV.name: np.round(RV.logp(self.test_point), round_vals) for RV in self.basic_RVs},
{RV.name: np.round(RV.logp(test_point), round_vals) for RV in self.basic_RVs},
name="Log-probability of test_point",
)

Expand Down
11 changes: 11 additions & 0 deletions pymc3/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
arraystep,
)
from .util import (
check_start_vals,
update_start_vals,
get_untransformed_name,
is_transformed_name,
Expand Down Expand Up @@ -419,7 +420,16 @@ def sample(

"""
model = modelcontext(model)
if start is None:
start = model.test_point
else:
if isinstance(start, dict):
update_start_vals(start, model.test_point, model)
else:
for chain_start_vals in start:
update_start_vals(chain_start_vals, model.test_point, model)

check_start_vals(start, model)
if cores is None:
cores = min(4, _cpu_count())

Expand Down Expand Up @@ -487,6 +497,7 @@ def sample(
progressbar=progressbar,
**kwargs,
)
check_start_vals(start_, model)
if start is None:
start = start_
except (AttributeError, NotImplementedError, tg.NullTypeGradError):
Expand Down
15 changes: 0 additions & 15 deletions pymc3/tests/test_hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@

from . import models
from pymc3.step_methods.hmc.base_hmc import BaseHMC
from pymc3.exceptions import SamplingError
import pymc3
import pytest
import logging
from pymc3.theanof import floatX

Expand Down Expand Up @@ -57,16 +55,3 @@ def test_nuts_tuning():

assert not step.tune
assert np.all(trace["step_size"][5:] == trace["step_size"][5])


def test_nuts_error_reporting(caplog):
    """Sampling a model whose test point has non-finite logp must raise
    SamplingError and log which variables are responsible.

    Note: the original used ``with A and B:``, which evaluates to ``B`` and
    silently skips ``caplog.at_level``; the assert was also unreachable
    because ``pymc3.sample`` raises before it runs. Both are fixed here.
    """
    model = pymc3.Model()
    # Enter BOTH context managers (comma-joined, not `and`-joined).
    with caplog.at_level(logging.CRITICAL), pytest.raises(SamplingError):
        with model:
            # `a` has an invalid (negative) test value for a HalfNormal.
            pymc3.HalfNormal("a", sigma=1, transform=None, testval=-1)
            pymc3.HalfNormal("b", sigma=1, transform=None)
            pymc3.sample(init="adapt_diag", chains=1)
    # Checked after the exception has propagated out of the `with` block.
    assert (
        "Bad initial energy, check any log probabilities that are inf or -inf: a -inf\nb"
        in caplog.text
    )
34 changes: 34 additions & 0 deletions pymc3/tests/test_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,40 @@ def test_soft_update_parent(self):
assert_almost_equal(start["interv_interval__"], test_point["interv_interval__"])


class TestCheckStartVals(SeededTest):
    """Tests for ``pm.util.check_start_vals`` validation of MCMC start points."""

    def test_valid_start_point(self):
        """A finite, in-support start point passes without raising."""
        with pm.Model() as model:
            pm.Uniform("a", lower=0.0, upper=1.0)
            pm.Uniform("b", lower=2.0, upper=3.0)

        start = {"a": 0.3, "b": 2.1}
        pm.util.update_start_vals(start, model.test_point, model)
        pm.util.check_start_vals(start, model)

    def test_invalid_start_point(self):
        """NaN start values yield a non-finite logp and raise SamplingError."""
        with pm.Model() as model:
            pm.Uniform("a", lower=0.0, upper=1.0)
            pm.Uniform("b", lower=2.0, upper=3.0)

        start = {"a": np.nan, "b": np.nan}
        pm.util.update_start_vals(start, model.test_point, model)
        with pytest.raises(pm.exceptions.SamplingError):
            pm.util.check_start_vals(start, model)

    def test_invalid_variable_name(self):
        """A start key absent from the model's variables raises KeyError."""
        with pm.Model() as model:
            pm.Uniform("a", lower=0.0, upper=1.0)
            pm.Uniform("b", lower=2.0, upper=3.0)

        start = {"a": 0.3, "b": 2.1, "c": 1.0}
        pm.util.update_start_vals(start, model.test_point, model)
        with pytest.raises(KeyError):
            pm.util.check_start_vals(start, model)


class TestExceptions:
def test_shape_error(self):
with pytest.raises(pm.exceptions.ShapeError) as exinfo:
Expand Down
10 changes: 2 additions & 8 deletions pymc3/tuning/starting.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
from ..theanof import inputvars
import theano.gradient as tg
from ..blocking import DictToArrayBijection, ArrayOrdering
from ..util import update_start_vals, get_default_varnames, get_var_name
from ..util import check_start_vals, update_start_vals, get_default_varnames, get_var_name

import warnings
from inspect import getargspec
Expand Down Expand Up @@ -89,13 +89,7 @@ def find_MAP(
else:
update_start_vals(start, model.test_point, model)

if not set(start.keys()).issubset(model.named_vars.keys()):
extra_keys = ", ".join(set(start.keys()) - set(model.named_vars.keys()))
valid_keys = ", ".join(model.named_vars.keys())
raise KeyError(
"Some start parameters do not appear in the model!\n"
"Valid keys are: {}, but {} was supplied".format(valid_keys, extra_keys)
)
check_start_vals(start, model)

if vars is None:
vars = model.cont_vars
Expand Down
54 changes: 48 additions & 6 deletions pymc3/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@
import functools
from typing import List, Dict, Tuple, Union

import numpy as np
import xarray
import arviz
from numpy import asscalar, ndarray

from pymc3.exceptions import SamplingError
from theano.tensor import TensorVariable


Expand Down Expand Up @@ -149,7 +150,7 @@ def get_repr_for_variable(variable, formatting="plain"):
pass
value = variable.eval()
if not value.shape or value.shape == (1,):
return asscalar(value)
return value.item()
return "array"

if formatting == "latex":
Expand Down Expand Up @@ -188,6 +189,47 @@ def update_start_vals(a, b, model):
a.update({k: v for k, v in b.items() if k not in a})


def check_start_vals(start, model):
    r"""Check that the starting values for MCMC do not cause the relevant log
    probability to evaluate to something invalid (e.g. Inf or NaN).

    Parameters
    ----------
    start : dict, or array of dict
        Starting point in parameter space (or partial point).
        Defaults to ``trace.point(-1)`` if there is a trace provided and
        ``model.test_point`` if not (defaults to empty dict). Initialization
        methods for NUTS (see ``init`` keyword) can overwrite the default.
    model : Model object
        Model whose per-variable log probabilities are evaluated at each
        starting point via ``model.check_test_point``.

    Raises
    ------
    KeyError
        If the parameters provided by `start` do not agree with the
        parameters contained within `model`.
    pymc3.exceptions.SamplingError
        If the evaluation of the parameters in `start` leads to an invalid
        (i.e. non-finite) state.

    Returns
    -------
    None
    """
    # Accept either a single start dict or a sequence of per-chain dicts.
    start_points = [start] if isinstance(start, dict) else start
    for elem in start_points:
        # Reject keys that do not name a variable of the model.
        if not set(elem.keys()).issubset(model.named_vars.keys()):
            extra_keys = ", ".join(set(elem.keys()) - set(model.named_vars.keys()))
            valid_keys = ", ".join(model.named_vars.keys())
            raise KeyError(
                "Some start parameters do not appear in the model!\n"
                "Valid keys are: {}, but {} was supplied".format(valid_keys, extra_keys)
            )

        initial_eval = model.check_test_point(test_point=elem)

        # Every per-variable logp must be finite for sampling to start.
        if not np.all(np.isfinite(initial_eval)):
            raise SamplingError(
                "Initial evaluation of model at starting point failed!\n"
                "Starting values:\n{}\n\n"
                "Initial evaluation results:\n{}".format(elem, str(initial_eval))
            )


def get_transformed(z):
if hasattr(z, "transformed"):
z = z.transformed
Expand All @@ -214,13 +256,13 @@ def enhanced(*args, **kwargs):

# FIXME: this function is poorly named, because it returns a LIST of
# points, not a dictionary of points.
def dataset_to_point_dict(ds: xarray.Dataset) -> List[Dict[str, ndarray]]:
def dataset_to_point_dict(ds: xarray.Dataset) -> List[Dict[str, np.ndarray]]:
# grab posterior samples for each variable
_samples: Dict[str, ndarray] = {vn: ds[vn].values for vn in ds.keys()}
_samples: Dict[str, np.ndarray] = {vn: ds[vn].values for vn in ds.keys()}
# make dicts
points: List[Dict[str, ndarray]] = []
points: List[Dict[str, np.ndarray]] = []
vn: str
s: ndarray
s: np.ndarray
for c in ds.chain:
for d in ds.draw:
points.append({vn: s[c, d] for vn, s in _samples.items()})
Expand Down