Skip to content

Commit

Permalink
Show pickling issues in notebook on windows (#3991)
Browse files Browse the repository at this point in the history
* Merge close remote connection

* Manually pickle step method in multiprocess sampling

* Fix tests for extra divergence info

* Add test for remote process crash

* Better formatting in test_parallel_sampling

Co-authored-by: Junpeng Lao <[email protected]>

* Use mp_ctx forkserver on MacOS

* Add test for pickle with dill

Co-authored-by: Junpeng Lao <[email protected]>
  • Loading branch information
aseyboldt and junpenglao authored Jul 7, 2020
1 parent 77873e9 commit 90f48ed
Show file tree
Hide file tree
Showing 6 changed files with 254 additions and 83 deletions.
5 changes: 5 additions & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
# Release Notes

## PyMC3 3.9.x (on deck)

### Maintenance
- Fix an error on Windows and Mac where error message from unpickling models did not show up in the notebook, or where sampling froze when a worker process crashed (see [#3991](https://github.com/pymc-devs/pymc3/pull/3991)).

### Documentation
- Notebook on [multilevel modeling](https://docs.pymc.io/notebooks/multilevel_modeling.html) has been rewritten to showcase ArviZ and xarray usage for inference result analysis (see [#3963](https://github.com/pymc-devs/pymc3/pull/3963))

### New features
- Introduce optional arguments to `pm.sample`: `mp_ctx` to control how the processes for parallel sampling are started, and `pickle_backend` to specify which library is used to pickle models in parallel sampling when the multiprocessing cnotext is not of type `fork`. (see [#3991](https://github.com/pymc-devs/pymc3/pull/3991))
- Add sampler stats `process_time_diff`, `perf_counter_diff` and `perf_counter_start`, that record wall and CPU times for each NUTS and HMC sample (see [ #3986](https://github.com/pymc-devs/pymc3/pull/3986)).

## PyMC3 3.9.2 (24 June 2020)
Expand Down
6 changes: 0 additions & 6 deletions pymc3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,6 @@
handler = logging.StreamHandler()
_log.addHandler(handler)

# Set start method to forkserver for MacOS to enable multiprocessing
# Closes issue https://github.com/pymc-devs/pymc3/issues/3849
sys = platform.system()
if sys == "Darwin":
new_context = mp.get_context("forkserver")


def __set_compiler_flags():
# Workarounds for Theano compiler problems on various platforms
Expand Down
220 changes: 150 additions & 70 deletions pymc3/parallel_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@
import ctypes
import time
import logging
import pickle
from collections import namedtuple
import traceback
import platform
from pymc3.exceptions import SamplingError
import errno

import numpy as np
from fastprogress.fastprogress import progress_bar
Expand All @@ -30,37 +31,6 @@
logger = logging.getLogger("pymc3")


def _get_broken_pipe_exception():
import sys

if sys.platform == "win32":
return RuntimeError(
"The communication pipe between the main process "
"and its spawned children is broken.\n"
"In Windows OS, this usually means that the child "
"process raised an exception while it was being "
"spawned, before it was setup to communicate to "
"the main process.\n"
"The exceptions raised by the child process while "
"spawning cannot be caught or handled from the "
"main process, and when running from an IPython or "
"jupyter notebook interactive kernel, the child's "
"exception and traceback appears to be lost.\n"
"A known way to see the child's error, and try to "
"fix or handle it, is to run the problematic code "
"as a batch script from a system's Command Prompt. "
"The child's exception will be printed to the "
"Command Promt's stderr, and it should be visible "
"above this error and traceback.\n"
"Note that if running a jupyter notebook that was "
"invoked from a Command Prompt, the child's "
"exception should have been printed to the Command "
"Prompt on which the notebook is running."
)
else:
return None


class ParallelSamplingError(Exception):
def __init__(self, message, chain, warnings=None):
super().__init__(message)
Expand Down Expand Up @@ -104,26 +74,65 @@ def rebuild_exc(exc, tb):
# ('start',)


class _Process(multiprocessing.Process):
class _Process:
"""Seperate process for each chain.
We communicate with the main process using a pipe,
and send finished samples using shared memory.
"""

def __init__(self, name:str, msg_pipe, step_method, shared_point, draws:int, tune:int, seed):
super().__init__(daemon=True, name=name)
def __init__(
self,
name: str,
msg_pipe,
step_method,
step_method_is_pickled,
shared_point,
draws: int,
tune: int,
seed,
pickle_backend,
):
self._msg_pipe = msg_pipe
self._step_method = step_method
self._step_method_is_pickled = step_method_is_pickled
self._shared_point = shared_point
self._seed = seed
self._tt_seed = seed + 1
self._draws = draws
self._tune = tune
self._pickle_backend = pickle_backend

def _unpickle_step_method(self):
unpickle_error = (
"The model could not be unpickled. This is required for sampling "
"with more than one core and multiprocessing context spawn "
"or forkserver."
)
if self._step_method_is_pickled:
if self._pickle_backend == 'pickle':
try:
self._step_method = pickle.loads(self._step_method)
except Exception:
raise ValueError(unpickle_error)
elif self._pickle_backend == 'dill':
try:
import dill
except ImportError:
raise ValueError(
"dill must be installed for pickle_backend='dill'."
)
try:
self._step_method = dill.loads(self._step_method)
except Exception:
raise ValueError(unpickle_error)
else:
raise ValueError("Unknown pickle backend")

def run(self):
try:
# We do not create this in __init__, as pickling this
# would destroy the shared memory.
self._unpickle_step_method()
self._point = self._make_numpy_refs()
self._start_loop()
except KeyboardInterrupt:
Expand Down Expand Up @@ -219,10 +228,25 @@ def _collect_warnings(self):
return []


def _run_process(*args):
_Process(*args).run()


class ProcessAdapter:
"""Control a Chain process from the main thread."""

def __init__(self, draws:int, tune:int, step_method, chain:int, seed, start):
def __init__(
self,
draws: int,
tune: int,
step_method,
step_method_pickled,
chain: int,
seed,
start,
mp_ctx,
pickle_backend,
):
self.chain = chain
process_name = "worker_chain_%s" % chain
self._msg_pipe, remote_conn = multiprocessing.Pipe()
Expand All @@ -237,7 +261,7 @@ def __init__(self, draws:int, tune:int, step_method, chain:int, seed, start):
if size != ctypes.c_size_t(size).value:
raise ValueError("Variable %s is too large" % name)

array = multiprocessing.sharedctypes.RawArray("c", size)
array = mp_ctx.RawArray("c", size)
self._shared_point[name] = array
array_np = np.frombuffer(array, dtype).reshape(shape)
array_np[...] = start[name]
Expand All @@ -246,27 +270,31 @@ def __init__(self, draws:int, tune:int, step_method, chain:int, seed, start):
self._readable = True
self._num_samples = 0

self._process = _Process(
process_name,
remote_conn,
step_method,
self._shared_point,
draws,
tune,
seed,
if step_method_pickled is not None:
step_method_send = step_method_pickled
else:
step_method_send = step_method

self._process = mp_ctx.Process(
daemon=True,
name=process_name,
target=_run_process,
args=(
process_name,
remote_conn,
step_method_send,
step_method_pickled is not None,
self._shared_point,
draws,
tune,
seed,
pickle_backend,
)
)
try:
self._process.start()
except IOError as e:
# Something may have gone wrong during the fork / spawn
if e.errno == errno.EPIPE:
exc = _get_broken_pipe_exception()
if exc is not None:
# Sleep a little to give the child process time to flush
# all its error message
time.sleep(0.2)
raise exc
raise
self._process.start()
# Close the remote pipe, so that we get notified if the other
# end is closed.
remote_conn.close()

@property
def shared_point_view(self):
Expand All @@ -277,15 +305,38 @@ def shared_point_view(self):
raise RuntimeError()
return self._point

def _send(self, msg, *args):
try:
self._msg_pipe.send((msg, *args))
except Exception:
# try to recive an error message
message = None
try:
message = self._msg_pipe.recv()
except Exception:
pass
if message is not None and message[0] == "error":
warns, old_error = message[1:]
if warns is not None:
error = ParallelSamplingError(
str(old_error),
self.chain,
warns
)
else:
error = RuntimeError("Chain %s failed." % self.chain)
raise error from old_error
raise

def start(self):
self._msg_pipe.send(("start",))
self._send("start")

def write_next(self):
self._readable = False
self._msg_pipe.send(("write_next",))
self._send("write_next")

def abort(self):
self._msg_pipe.send(("abort",))
self._send("abort")

def join(self, timeout=None):
self._process.join(timeout)
Expand Down Expand Up @@ -324,7 +375,7 @@ def terminate_all(processes, patience=2):
for process in processes:
try:
process.abort()
except EOFError:
except Exception:
pass

start_time = time.time()
Expand Down Expand Up @@ -353,23 +404,52 @@ def terminate_all(processes, patience=2):
class ParallelSampler:
def __init__(
self,
draws:int,
tune:int,
chains:int,
cores:int,
seeds:list,
start_points:list,
draws: int,
tune: int,
chains: int,
cores: int,
seeds: list,
start_points: list,
step_method,
start_chain_num:int=0,
progressbar:bool=True,
start_chain_num: int = 0,
progressbar: bool = True,
mp_ctx=None,
pickle_backend: str = 'pickle',
):

if any(len(arg) != chains for arg in [seeds, start_points]):
raise ValueError("Number of seeds and start_points must be %s." % chains)

if mp_ctx is None or isinstance(mp_ctx, str):
# Closes issue https://github.com/pymc-devs/pymc3/issues/3849
if platform.system() == 'Darwin':
mp_ctx = "forkserver"
mp_ctx = multiprocessing.get_context(mp_ctx)

step_method_pickled = None
if mp_ctx.get_start_method() != 'fork':
if pickle_backend == 'pickle':
step_method_pickled = pickle.dumps(step_method, protocol=-1)
elif pickle_backend == 'dill':
try:
import dill
except ImportError:
raise ValueError(
"dill must be installed for pickle_backend='dill'."
)
step_method_pickled = dill.dumps(step_method, protocol=-1)

self._samplers = [
ProcessAdapter(
draws, tune, step_method, chain + start_chain_num, seed, start
draws,
tune,
step_method,
step_method_pickled,
chain + start_chain_num,
seed,
start,
mp_ctx,
pickle_backend
)
for chain, seed, start in zip(range(chains), seeds, start_points)
]
Expand Down
Loading

0 comments on commit 90f48ed

Please sign in to comment.