diff --git a/pymc3/parallel_sampling.py b/pymc3/parallel_sampling.py index 3caa4ff543d..29192e38f45 100644 --- a/pymc3/parallel_sampling.py +++ b/pymc3/parallel_sampling.py @@ -257,6 +257,9 @@ def __init__(self, draws:int, tune:int, step_method, chain:int, seed, start): ) try: self._process.start() + # Close the remote pipe, so that we get notified if the other + # end is closed. + remote_conn.close() except IOError as e: # Something may have gone wrong during the fork / spawn if e.errno == errno.EPIPE: @@ -285,7 +288,10 @@ def write_next(self): self._msg_pipe.send(("write_next",)) def abort(self): - self._msg_pipe.send(("abort",)) + try: + self._msg_pipe.send(("abort",)) + except BrokenPipeError: + pass def join(self, timeout=None): self._process.join(timeout) diff --git a/pymc3/tests/test_sampling.py b/pymc3/tests/test_sampling.py index e80599f912d..f2dbf77bbc0 100644 --- a/pymc3/tests/test_sampling.py +++ b/pymc3/tests/test_sampling.py @@ -11,7 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import ctypes from itertools import combinations import packaging from typing import Tuple @@ -907,6 +907,7 @@ def test_bounded_dist(self): prior_trace = pm.sample_prior_predictive(5) assert prior_trace["x"].shape == (5, 3, 1) + class TestSamplePosteriorPredictive: def test_point_list_arg_bug_fspp(self, point_list_arg_bug_fixture): pmodel, trace = point_list_arg_bug_fixture @@ -953,3 +954,30 @@ def test_sample_from_xarray_posterior_fast(self, point_list_arg_bug_fixture): idat.posterior, var_names=['d'] ) + + +tt_vector = tt.TensorType(theano.config.floatX, [False]) +@theano.as_op([tt_vector], [tt_vector]) +def segfault_on_negative(a): + if np.any(a < 0): + # Segfault + ctypes.string_at(0) + return 2 * np.array(a) + + +class TestIssues: + def test_3988(self): + """ Chain crashing on a child process should raise an error on the parent. """ + with pm.Model() as model: + # the test_val is positive such that it passes testval computations: + x = pm.Normal('x', mu=0.1, shape=2) + pm.Normal('y', mu=segfault_on_negative(x), shape=2) + + # expecting to crash child process when it goes to x < 0 during sampling + with pytest.raises(RuntimeError, match=r"Chain \d failed."): + pm.sample( + step=pm.Metropolis(), + cores=2, chains=2, + tune=100, draws=300, + compute_convergence_checks=False + )