diff --git a/pymc3/parallel_sampling.py b/pymc3/parallel_sampling.py index 5aa49071e3..106f6cd55f 100644 --- a/pymc3/parallel_sampling.py +++ b/pymc3/parallel_sampling.py @@ -257,6 +257,10 @@ def __init__( if step_method_pickled is not None: step_method_send = step_method_pickled else: + if mp_ctx.get_start_method() == "spawn": + raise ValueError( + "please provide a pre-pickled step method when multiprocessing start method is 'spawn'" + ) step_method_send = step_method self._process = mp_ctx.Process( diff --git a/pymc3/step_methods/hmc/base_hmc.py b/pymc3/step_methods/hmc/base_hmc.py index 915c6d4b6a..c5e9603a90 100644 --- a/pymc3/step_methods/hmc/base_hmc.py +++ b/pymc3/step_methods/hmc/base_hmc.py @@ -101,8 +101,8 @@ def __init__( # XXX: If the dimensions of these terms change, the step size # dimension-scaling should change as well, no? test_point = self._model.initial_point - continuous_vars = [test_point[v.name] for v in self._model.cont_vars] - size = sum(v.size for v in continuous_vars) + nuts_vars = [test_point[v.name] for v in vars] + size = sum(v.size for v in nuts_vars) self.step_size = step_scale / (size ** 0.25) self.step_adapt = step_sizes.DualAverageAdaptation( diff --git a/pymc3/tests/test_parallel_sampling.py b/pymc3/tests/test_parallel_sampling.py index d58604b93e..8bdc3ca1dc 100644 --- a/pymc3/tests/test_parallel_sampling.py +++ b/pymc3/tests/test_parallel_sampling.py @@ -13,9 +13,11 @@ # limitations under the License. import multiprocessing import os +import platform import aesara import aesara.tensor as at +import cloudpickle import numpy as np import pytest @@ -25,6 +27,8 @@ import pymc3 as pm import pymc3.parallel_sampling as ps +from pymc3.aesaraf import floatX + def test_context(): with pm.Model(): @@ -83,20 +87,27 @@ def test_remote_pipe_closed(): pm.sample(step=step, mp_ctx="spawn", tune=2, draws=2, cores=2, chains=2) -@pytest.mark.xfail( - reason="Possibly the same issue described in https://github.com/pymc-devs/pymc3/pull/4701" -) -def test_abort(): +@pytest.mark.skip(reason="Unclear") +@pytest.mark.parametrize("mp_start_method", ["spawn", "fork"]) +def test_abort(mp_start_method): with pm.Model() as model: a = pm.Normal("a", shape=1) - pm.HalfNormal("b") - step1 = pm.NUTS([a]) - step2 = pm.Metropolis([model["b_log__"]]) + b = pm.HalfNormal("b") + step1 = pm.NUTS([model.rvs_to_values[a]]) + step2 = pm.Metropolis([model.rvs_to_values[b]]) step = pm.CompoundStep([step1, step2]) + # on Windows we cannot fork + if platform.system() == "Windows" and mp_start_method == "fork": + return + if mp_start_method == "spawn": + step_method_pickled = cloudpickle.dumps(step, protocol=-1) + else: + step_method_pickled = None + for abort in [False, True]: - ctx = multiprocessing.get_context() + ctx = multiprocessing.get_context(mp_start_method) proc = ps.ProcessAdapter( 10, 10, @@ -104,8 +115,8 @@ def test_abort(): chain=3, seed=1, mp_ctx=ctx, - start={"a": np.array([1.0]), "b_log__": np.array(2.0)}, - step_method_pickled=None, + start={"a": floatX(np.array([1.0])), "b_log__": floatX(np.array(2.0))}, + step_method_pickled=step_method_pickled, ) proc.start() while True: @@ -118,19 +129,25 @@ def test_abort(): proc.join() -@pytest.mark.xfail( - reason="Possibly the same issue described in https://github.com/pymc-devs/pymc3/pull/4701" -) -def test_explicit_sample(): +@pytest.mark.parametrize("mp_start_method", ["spawn", "fork"]) +def test_explicit_sample(mp_start_method): with pm.Model() as model: a = pm.Normal("a", shape=1) - pm.HalfNormal("b") - step1 = pm.NUTS([a]) - step2 = pm.Metropolis([model["b_log__"]]) + b = pm.HalfNormal("b") + step1 = pm.NUTS([model.rvs_to_values[a]]) + step2 = pm.Metropolis([model.rvs_to_values[b]]) step = pm.CompoundStep([step1, step2]) - ctx = multiprocessing.get_context() + # on Windows we cannot fork + if platform.system() == "Windows" and mp_start_method == "fork": + return + if mp_start_method == "spawn": + step_method_pickled = cloudpickle.dumps(step, protocol=-1) + else: + step_method_pickled = None + + ctx = multiprocessing.get_context(mp_start_method) proc = ps.ProcessAdapter( 10, 10, @@ -138,8 +155,8 @@ def test_explicit_sample(): chain=3, seed=1, mp_ctx=ctx, - start={"a": np.array([1.0]), "b_log__": np.array(2.0)}, - step_method_pickled=None, + start={"a": floatX(np.array([1.0])), "b_log__": floatX(np.array(2.0))}, + step_method_pickled=step_method_pickled, ) proc.start() while True: @@ -153,19 +170,16 @@ def test_explicit_sample(): proc.join() -@pytest.mark.xfail( - reason="Possibly the same issue described in https://github.com/pymc-devs/pymc3/pull/4701" -) def test_iterator(): with pm.Model() as model: a = pm.Normal("a", shape=1) - pm.HalfNormal("b") - step1 = pm.NUTS([a]) - step2 = pm.Metropolis([model["b_log__"]]) + b = pm.HalfNormal("b") + step1 = pm.NUTS([model.rvs_to_values[a]]) + step2 = pm.Metropolis([model.rvs_to_values[b]]) step = pm.CompoundStep([step1, step2]) - start = {"a": np.array([1.0]), "b_log__": np.array(2.0)} + start = {"a": floatX(np.array([1.0])), "b_log__": floatX(np.array(2.0))} sampler = ps.ParallelSampler(10, 10, 3, 2, [2, 3, 4], [start] * 3, step, 0, False) with sampler: for draw in sampler: