From 301430f89c23a0de5e7f6064036cc3e68c701364 Mon Sep 17 00:00:00 2001 From: Stephen Hogg Date: Tue, 24 Nov 2020 22:40:45 +1100 Subject: [PATCH 1/8] Re-create branch --- RELEASE-NOTES.md | 6 +++++ pymc3/model.py | 2 +- pymc3/sampling.py | 11 ++++++++ pymc3/tests/test_hmc.py | 15 ----------- pymc3/tests/test_util.py | 34 +++++++++++++++++++++++++ pymc3/tuning/starting.py | 10 ++------ pymc3/util.py | 54 +++++++++++++++++++++++++++++++++++----- 7 files changed, 102 insertions(+), 30 deletions(-) diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index 9199eaa931..c1564f99d1 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -1,5 +1,11 @@ # Release Notes +## PyMC3 3.10.0 + +### Maintenance +- Test model logp before starting any MCMC chains (see [#4116](https://github.com/pymc-devs/pymc3/issues/4116)) +- Fix bug in `model.check_test_point` that caused the `test_point` argument to be ignored. (see [PR #4211](https://github.com/pymc-devs/pymc3/pull/4211#issuecomment-727142721)) + ## PyMC3 3.9.x (on deck) ### Maintenance diff --git a/pymc3/model.py b/pymc3/model.py index 126c753742..b3edfdd406 100644 --- a/pymc3/model.py +++ b/pymc3/model.py @@ -1365,7 +1365,7 @@ def check_test_point(self, test_point=None, round_vals=2): test_point = self.test_point return Series( - {RV.name: np.round(RV.logp(self.test_point), round_vals) for RV in self.basic_RVs}, + {RV.name: np.round(RV.logp(test_point), round_vals) for RV in self.basic_RVs}, name="Log-probability of test_point", ) diff --git a/pymc3/sampling.py b/pymc3/sampling.py index d6d4099293..0bf5488432 100644 --- a/pymc3/sampling.py +++ b/pymc3/sampling.py @@ -53,6 +53,7 @@ arraystep, ) from .util import ( + check_start_vals, update_start_vals, get_untransformed_name, is_transformed_name, @@ -419,7 +420,16 @@ def sample( """ model = modelcontext(model) + if start is None: + start = model.test_point + else: + if isinstance(start, dict): + update_start_vals(start, model.test_point, model) + else: + for chain_start_vals in start: + update_start_vals(chain_start_vals, model.test_point, model) + check_start_vals(start, model) if cores is None: cores = min(4, _cpu_count()) @@ -487,6 +497,7 @@ def sample( progressbar=progressbar, **kwargs, ) + check_start_vals(start_) if start is None: start = start_ except (AttributeError, NotImplementedError, tg.NullTypeGradError): diff --git a/pymc3/tests/test_hmc.py b/pymc3/tests/test_hmc.py index 1dd3e42acb..384501f2fd 100644 --- a/pymc3/tests/test_hmc.py +++ b/pymc3/tests/test_hmc.py @@ -17,9 +17,7 @@ from . import models from pymc3.step_methods.hmc.base_hmc import BaseHMC -from pymc3.exceptions import SamplingError import pymc3 -import pytest import logging from pymc3.theanof import floatX @@ -57,16 +55,3 @@ def test_nuts_tuning(): assert not step.tune assert np.all(trace["step_size"][5:] == trace["step_size"][5]) - - -def test_nuts_error_reporting(caplog): - model = pymc3.Model() - with caplog.at_level(logging.CRITICAL) and pytest.raises(SamplingError): - with model: - pymc3.HalfNormal("a", sigma=1, transform=None, testval=-1) - pymc3.HalfNormal("b", sigma=1, transform=None) - trace = pymc3.sample(init="adapt_diag", chains=1) - assert ( - "Bad initial energy, check any log probabilities that are inf or -inf: a -inf\nb" - in caplog.text - ) diff --git a/pymc3/tests/test_util.py b/pymc3/tests/test_util.py index cf1f632a5a..63fda20457 100644 --- a/pymc3/tests/test_util.py +++ b/pymc3/tests/test_util.py @@ -95,6 +95,40 @@ def test_soft_update_parent(self): assert_almost_equal(start["interv_interval__"], test_point["interv_interval__"]) +class TestCheckStartVals(SeededTest): + def setup_method(self): + super().setup_method() + + def test_valid_start_point(self): + with pm.Model() as model: + a = pm.Uniform("a", lower=0.0, upper=1.0) + b = pm.Uniform("b", lower=2.0, upper=3.0) + + start = {"a": 0.3, "b": 2.1} + pm.util.update_start_vals(start, model.test_point, model) + pm.util.check_start_vals(start, model) + + def test_invalid_start_point(self): + with pm.Model() as model: + a = pm.Uniform("a", lower=0.0, upper=1.0) + b = pm.Uniform("b", lower=2.0, upper=3.0) + + start = {"a": np.nan, "b": np.nan} + pm.util.update_start_vals(start, model.test_point, model) + with pytest.raises(pm.exceptions.SamplingError): + pm.util.check_start_vals(start, model) + + def test_invalid_variable_name(self): + with pm.Model() as model: + a = pm.Uniform("a", lower=0.0, upper=1.0) + b = pm.Uniform("b", lower=2.0, upper=3.0) + + start = {"a": 0.3, "b": 2.1, "c": 1.0} + pm.util.update_start_vals(start, model.test_point, model) + with pytest.raises(KeyError): + pm.util.check_start_vals(start, model) + + class TestExceptions: def test_shape_error(self): with pytest.raises(pm.exceptions.ShapeError) as exinfo: diff --git a/pymc3/tuning/starting.py b/pymc3/tuning/starting.py index db49ef52ae..6c83af33f0 100644 --- a/pymc3/tuning/starting.py +++ b/pymc3/tuning/starting.py @@ -28,7 +28,7 @@ from ..theanof import inputvars import theano.gradient as tg from ..blocking import DictToArrayBijection, ArrayOrdering -from ..util import update_start_vals, get_default_varnames, get_var_name +from ..util import check_start_vals, update_start_vals, get_default_varnames, get_var_name import warnings from inspect import getargspec @@ -89,13 +89,7 @@ def find_MAP( else: update_start_vals(start, model.test_point, model) - if not set(start.keys()).issubset(model.named_vars.keys()): - extra_keys = ", ".join(set(start.keys()) - set(model.named_vars.keys())) - valid_keys = ", ".join(model.named_vars.keys()) - raise KeyError( - "Some start parameters do not appear in the model!\n" - "Valid keys are: {}, but {} was supplied".format(valid_keys, extra_keys) - ) + check_start_vals(start, model) if vars is None: vars = model.cont_vars diff --git a/pymc3/util.py b/pymc3/util.py index 95c530825a..22668f7622 100644 --- a/pymc3/util.py +++ b/pymc3/util.py @@ -16,10 +16,11 @@ import functools from typing import List, Dict, Tuple, Union +import numpy as np import xarray import arviz -from numpy import asscalar, ndarray +from pymc3.exceptions import SamplingError from theano.tensor import TensorVariable @@ -149,7 +150,7 @@ def get_repr_for_variable(variable, formatting="plain"): pass value = variable.eval() if not value.shape or value.shape == (1,): - return asscalar(value) + return np.asscalar(value) return "array" if formatting == "latex": @@ -188,6 +189,47 @@ def update_start_vals(a, b, model): a.update({k: v for k, v in b.items() if k not in a}) +def check_start_vals(start, model): + r"""Check that the starting values for MCMC do not cause the relevant log probability + to evaluate to something invalid (e.g. Inf or NaN) + Parameters + ---------- + start : dict, or array of dict + Starting point in parameter space (or partial point) + Defaults to ``trace.point(-1))`` if there is a trace provided and model.test_point if not + (defaults to empty dict). Initialization methods for NUTS (see ``init`` keyword) can + overwrite the default. + model : Model object + Raises + ______ + KeyError if the parameters provided by `start` do not agree with the parameters contained + within `model` + pymc3.exceptions.SamplingError if the evaluation of the parameters in `start` leads to an + invalid (i.e. non-finite) state + Returns + ------- + None + """ + start_points = [start] if isinstance(start, dict) else start + for elem in start_points: + if not set(elem.keys()).issubset(model.named_vars.keys()): + extra_keys = ", ".join(set(elem.keys()) - set(model.named_vars.keys())) + valid_keys = ", ".join(model.named_vars.keys()) + raise KeyError( + "Some start parameters do not appear in the model!\n" + "Valid keys are: {}, but {} was supplied".format(valid_keys, extra_keys) + ) + + initial_eval = model.check_test_point(test_point=elem) + + if not np.all(np.isfinite(initial_eval)): + raise SamplingError( + "Initial evaluation of model at starting point failed!\n" + "Starting values:\n{}\n\n" + "Initial evaluation results:\n{}".format(elem, str(initial_eval)) + ) + + def get_transformed(z): if hasattr(z, "transformed"): z = z.transformed @@ -214,13 +256,13 @@ def enhanced(*args, **kwargs): # FIXME: this function is poorly named, because it returns a LIST of # points, not a dictionary of points. -def dataset_to_point_dict(ds: xarray.Dataset) -> List[Dict[str, ndarray]]: +def dataset_to_point_dict(ds: xarray.Dataset) -> List[Dict[str, np.ndarray]]: # grab posterior samples for each variable - _samples: Dict[str, ndarray] = {vn: ds[vn].values for vn in ds.keys()} + _samples: Dict[str, np.ndarray] = {vn: ds[vn].values for vn in ds.keys()} # make dicts - points: List[Dict[str, ndarray]] = [] + points: List[Dict[str, np.ndarray]] = [] vn: str - s: ndarray + s: np.ndarray for c in ds.chain: for d in ds.draw: points.append({vn: s[c, d] for vn, s in _samples.items()}) From ac28054a34cbc9265f2a895edec0edcc8d3189ae Mon Sep 17 00:00:00 2001 From: Stephen Hogg Date: Tue, 24 Nov 2020 22:45:26 +1100 Subject: [PATCH 2/8] Fix merge conflict --- pymc3/util.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/pymc3/util.py b/pymc3/util.py index 22668f7622..8256eda036 100644 --- a/pymc3/util.py +++ b/pymc3/util.py @@ -19,6 +19,7 @@ import numpy as np import xarray import arviz +from numpy import ndarray from pymc3.exceptions import SamplingError from theano.tensor import TensorVariable @@ -150,7 +151,7 @@ def get_repr_for_variable(variable, formatting="plain"): pass value = variable.eval() if not value.shape or value.shape == (1,): - return np.asscalar(value) + return value.item() return "array" if formatting == "latex": From de4b46db30cd0ae0ea5a267a8290a4814d78bab4 Mon Sep 17 00:00:00 2001 From: Stephen Hogg Date: Wed, 25 Nov 2020 19:02:03 +1100 Subject: [PATCH 3/8] bug fix --- pymc3/sampling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc3/sampling.py b/pymc3/sampling.py index 0bf5488432..43de34d258 100644 --- a/pymc3/sampling.py +++ b/pymc3/sampling.py @@ -497,7 +497,7 @@ def sample( progressbar=progressbar, **kwargs, ) - check_start_vals(start_) + check_start_vals(start_, model) if start is None: start = start_ except (AttributeError, NotImplementedError, tg.NullTypeGradError): From 64ed067805569d7d117f3556a5c94ed112dcdd62 Mon Sep 17 00:00:00 2001 From: Stephen Hogg Date: Wed, 25 Nov 2020 22:07:23 +1100 Subject: [PATCH 4/8] remove unneeded import --- pymc3/util.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pymc3/util.py b/pymc3/util.py index 8256eda036..ab1b09b5bb 100644 --- a/pymc3/util.py +++ b/pymc3/util.py @@ -19,7 +19,6 @@ import numpy as np import xarray import arviz -from numpy import ndarray from pymc3.exceptions import SamplingError from theano.tensor import TensorVariable From 9d3e0cebc5bc2b655124b41e468442e6a8cdfcb1 Mon Sep 17 00:00:00 2001 From: StephenHogg Date: Wed, 25 Nov 2020 22:36:47 +1100 Subject: [PATCH 5/8] Update pymc3/util.py as per twiecki Co-authored-by: Thomas Wiecki --- pymc3/util.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pymc3/util.py b/pymc3/util.py index ab1b09b5bb..572ed3eb07 100644 --- a/pymc3/util.py +++ b/pymc3/util.py @@ -192,6 +192,7 @@ def update_start_vals(a, b, model): def check_start_vals(start, model): r"""Check that the starting values for MCMC do not cause the relevant log probability to evaluate to something invalid (e.g. Inf or NaN) + Parameters ---------- start : dict, or array of dict From c9c40ff8608aa68fe7f7d6c8c3f41e7c1bdb9f80 Mon Sep 17 00:00:00 2001 From: Stephen Hogg Date: Fri, 27 Nov 2020 08:53:27 +1100 Subject: [PATCH 6/8] fix test_examples.py --- pymc3/tests/test_examples.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pymc3/tests/test_examples.py b/pymc3/tests/test_examples.py index b9eb850b0d..d31b2bfd3a 100644 --- a/pymc3/tests/test_examples.py +++ b/pymc3/tests/test_examples.py @@ -274,7 +274,7 @@ def build_model(self): # Estimated mean count theta = pm.Uniform("theta", 0, 100) # Poisson likelihood - pm.ZeroInflatedPoisson("y", theta, psi, observed=self.y) + pm.ZeroInflatedPoisson("y", psi, theta, observed=self.y) return model def test_run(self): From 2ec02a34b8afc65e653c760093cb132f5b25faf4 Mon Sep 17 00:00:00 2001 From: Stephen Hogg Date: Fri, 27 Nov 2020 09:36:59 +1100 Subject: [PATCH 7/8] fix test_step.py --- pymc3/tests/test_step.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pymc3/tests/test_step.py b/pymc3/tests/test_step.py index 61ba1b7f65..7d42e71445 100644 --- a/pymc3/tests/test_step.py +++ b/pymc3/tests/test_step.py @@ -963,15 +963,15 @@ def test_bad_init_nonparallel(self): HalfNormal("a", sigma=1, testval=-1, transform=None) with pytest.raises(SamplingError) as error: sample(init=None, chains=1, random_seed=1) - error.match("Bad initial") + error.match("Initial evaluation") @pytest.mark.skipif(sys.version_info < (3, 6), reason="requires python3.6 or higher") def test_bad_init_parallel(self): with Model(): HalfNormal("a", sigma=1, testval=-1, transform=None) - with pytest.raises(ParallelSamplingError) as error: + with pytest.raises(SamplingError) as error: sample(init=None, cores=2, random_seed=1) - error.match("Bad initial") + error.match("Initial evaluation") def test_linalg(self, caplog): with Model(): From ddf9fc30cf71d66e1b71549ff4a0814aeb10c447 Mon Sep 17 00:00:00 2001 From: Stephen Hogg Date: Fri, 27 Nov 2020 09:49:02 +1100 Subject: [PATCH 8/8] remove unnecessary import --- pymc3/tests/test_step.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pymc3/tests/test_step.py b/pymc3/tests/test_step.py index 7d42e71445..e115bdcb17 100644 --- a/pymc3/tests/test_step.py +++ b/pymc3/tests/test_step.py @@ -27,7 +27,6 @@ simple_2model_continuous, ) from pymc3.sampling import assign_step_methods, sample -from pymc3.parallel_sampling import ParallelSamplingError from pymc3.exceptions import SamplingError from pymc3.model import Model, Potential, set_data