Skip to content

Commit

Permalink
- Fix regression caused by #4211 (#4285)
Browse files Browse the repository at this point in the history
* - Fix regression caused by #4211

* - Add test to make sure jitter is being applied to chains starting points by default

* - Import appropriate empty context for python < 3.7

* - Apply black formatting

* - Change the second check_start_vals to explicitly run on the newly assigned start variable.

* - Improve test documentation and add a new condition

* Use monkeypatch for more robust test

* - Black formatting, once again...
  • Loading branch information
ricardoV94 authored Dec 5, 2020
1 parent 198d13e commit 9311899
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 4 deletions.
6 changes: 3 additions & 3 deletions pymc3/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -416,15 +416,15 @@ def sample(
"""
model = modelcontext(model)
if start is None:
start = model.test_point
check_start_vals(model.test_point, model)
else:
if isinstance(start, dict):
update_start_vals(start, model.test_point, model)
else:
for chain_start_vals in start:
update_start_vals(chain_start_vals, model.test_point, model)
check_start_vals(start, model)

check_start_vals(start, model)
if cores is None:
cores = min(4, _cpu_count())

Expand Down Expand Up @@ -492,9 +492,9 @@ def sample(
progressbar=progressbar,
**kwargs,
)
check_start_vals(start_, model)
if start is None:
start = start_
check_start_vals(start, model)
except (AttributeError, NotImplementedError, tg.NullTypeGradError):
# gradient computation failed
_log.info("Initializing NUTS failed. " "Falling back to elementwise auto-assignment.")
Expand Down
32 changes: 31 additions & 1 deletion pymc3/tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from contextlib import ExitStack as does_not_raise
from itertools import combinations
from typing import Tuple
import numpy as np
Expand All @@ -25,7 +26,7 @@
import theano
from pymc3.tests.models import simple_init
from pymc3.tests.helpers import SeededTest
from pymc3.exceptions import IncorrectArgumentsError
from pymc3.exceptions import IncorrectArgumentsError, SamplingError
from scipy import stats
import pytest

Expand Down Expand Up @@ -785,6 +786,35 @@ def test_exec_nuts_init(method):
assert "a" in start[0] and "b_log__" in start[0]


@pytest.mark.parametrize(
"init, start, expectation",
[
("auto", None, pytest.raises(SamplingError)),
("jitter+adapt_diag", None, pytest.raises(SamplingError)),
("auto", {"x": 0}, does_not_raise()),
("jitter+adapt_diag", {"x": 0}, does_not_raise()),
("adapt_diag", None, does_not_raise()),
],
)
def test_default_sample_nuts_jitter(init, start, expectation, monkeypatch):
# This test tries to check whether the starting points returned by init_nuts are actually
# being used when pm.sample() is called without specifying an explicit start point (see
# https://github.com/pymc-devs/pymc3/pull/4285).
def _mocked_init_nuts(*args, **kwargs):
if init == "adapt_diag":
start_ = [{"x": np.array(0.79788456)}]
else:
start_ = [{"x": np.array(-0.04949886)}]
_, step = pm.init_nuts(*args, **kwargs)
return start_, step

monkeypatch.setattr("pymc3.sampling.init_nuts", _mocked_init_nuts)
with pm.Model() as m:
x = pm.HalfNormal("x", transform=None)
with expectation:
pm.sample(tune=1, draws=0, chains=1, init=init, start=start)


@pytest.fixture(scope="class")
def point_list_arg_bug_fixture() -> Tuple[pm.Model, pm.backends.base.MultiTrace]:
with pm.Model() as pmodel:
Expand Down

0 comments on commit 9311899

Please sign in to comment.