Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Show pickling issues in notebook on windows #3991

Merged
merged 7 commits into from
Jul 7, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,15 @@
# Release Notes

## PyMC3 3.9.x (on deck)

### Maintenance
- Fix an error on Windows and Mac where error message from unpickling models did not show up in the notebook, or where sampling froze when a worker process crashed (see [#3991](https://github.com/pymc-devs/pymc3/pull/3991)).

### Documentation
- Notebook on [multilevel modeling](https://docs.pymc.io/notebooks/multilevel_modeling.html) has been rewritten to showcase ArviZ and xarray usage for inference result analysis (see [#3963](https://github.com/pymc-devs/pymc3/pull/3963))

### New features
- Introduce optional arguments to `pm.sample`: `mp_ctx` to control how the processes for parallel sampling are started, and `pickle_backend` to specify which library is used to pickle models in parallel sampling when the multiprocessing context is not of type `fork`. (see [#3991](https://github.com/pymc-devs/pymc3/pull/3991))
- Add sampler stats `process_time_diff`, `perf_counter_diff` and `perf_counter_start`, that record wall and CPU times for each NUTS and HMC sample (see [ #3986](https://github.com/pymc-devs/pymc3/pull/3986)).

## PyMC3 3.9.2 (24 June 2020)
Expand Down
6 changes: 0 additions & 6 deletions pymc3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,12 +27,6 @@
handler = logging.StreamHandler()
_log.addHandler(handler)

# Set start method to forkserver for MacOS to enable multiprocessing
# Closes issue https://github.com/pymc-devs/pymc3/issues/3849
sys = platform.system()
if sys == "Darwin":
new_context = mp.get_context("forkserver")


def __set_compiler_flags():
# Workarounds for Theano compiler problems on various platforms
Expand Down
220 changes: 150 additions & 70 deletions pymc3/parallel_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,11 @@
import ctypes
import time
import logging
import pickle
from collections import namedtuple
import traceback
import platform
from pymc3.exceptions import SamplingError
import errno

import numpy as np
from fastprogress.fastprogress import progress_bar
Expand All @@ -30,37 +31,6 @@
logger = logging.getLogger("pymc3")


def _get_broken_pipe_exception():
import sys

if sys.platform == "win32":
return RuntimeError(
"The communication pipe between the main process "
"and its spawned children is broken.\n"
"In Windows OS, this usually means that the child "
"process raised an exception while it was being "
"spawned, before it was setup to communicate to "
"the main process.\n"
"The exceptions raised by the child process while "
"spawning cannot be caught or handled from the "
"main process, and when running from an IPython or "
"jupyter notebook interactive kernel, the child's "
"exception and traceback appears to be lost.\n"
"A known way to see the child's error, and try to "
"fix or handle it, is to run the problematic code "
"as a batch script from a system's Command Prompt. "
"The child's exception will be printed to the "
"Command Promt's stderr, and it should be visible "
"above this error and traceback.\n"
"Note that if running a jupyter notebook that was "
"invoked from a Command Prompt, the child's "
"exception should have been printed to the Command "
"Prompt on which the notebook is running."
)
else:
return None


class ParallelSamplingError(Exception):
def __init__(self, message, chain, warnings=None):
super().__init__(message)
Expand Down Expand Up @@ -104,26 +74,65 @@ def rebuild_exc(exc, tb):
# ('start',)


class _Process(multiprocessing.Process):
class _Process:
"""Seperate process for each chain.
We communicate with the main process using a pipe,
and send finished samples using shared memory.
"""

def __init__(self, name:str, msg_pipe, step_method, shared_point, draws:int, tune:int, seed):
super().__init__(daemon=True, name=name)
def __init__(
self,
name: str,
msg_pipe,
step_method,
step_method_is_pickled,
shared_point,
draws: int,
tune: int,
seed,
pickle_backend,
):
self._msg_pipe = msg_pipe
self._step_method = step_method
self._step_method_is_pickled = step_method_is_pickled
self._shared_point = shared_point
self._seed = seed
self._tt_seed = seed + 1
self._draws = draws
self._tune = tune
self._pickle_backend = pickle_backend

def _unpickle_step_method(self):
unpickle_error = (
"The model could not be unpickled. This is required for sampling "
"with more than one core and multiprocessing context spawn "
"or forkserver."
)
if self._step_method_is_pickled:
if self._pickle_backend == 'pickle':
try:
self._step_method = pickle.loads(self._step_method)
except Exception:
raise ValueError(unpickle_error)
elif self._pickle_backend == 'dill':
try:
import dill
except ImportError:
raise ValueError(
"dill must be installed for pickle_backend='dill'."
)
try:
self._step_method = dill.loads(self._step_method)
except Exception:
raise ValueError(unpickle_error)
else:
raise ValueError("Unknown pickle backend")

def run(self):
try:
# We do not create this in __init__, as pickling this
# would destroy the shared memory.
self._unpickle_step_method()
self._point = self._make_numpy_refs()
self._start_loop()
except KeyboardInterrupt:
Expand Down Expand Up @@ -219,10 +228,25 @@ def _collect_warnings(self):
return []


def _run_process(*args):
    """Target function for the worker process.

    Builds a ``_Process`` from the spawn arguments and runs its main loop.
    """
    worker = _Process(*args)
    worker.run()


class ProcessAdapter:
"""Control a Chain process from the main thread."""

def __init__(self, draws:int, tune:int, step_method, chain:int, seed, start):
def __init__(
self,
draws: int,
tune: int,
step_method,
step_method_pickled,
chain: int,
seed,
start,
mp_ctx,
pickle_backend,
):
self.chain = chain
process_name = "worker_chain_%s" % chain
self._msg_pipe, remote_conn = multiprocessing.Pipe()
Expand All @@ -237,7 +261,7 @@ def __init__(self, draws:int, tune:int, step_method, chain:int, seed, start):
if size != ctypes.c_size_t(size).value:
raise ValueError("Variable %s is too large" % name)

array = multiprocessing.sharedctypes.RawArray("c", size)
array = mp_ctx.RawArray("c", size)
self._shared_point[name] = array
array_np = np.frombuffer(array, dtype).reshape(shape)
array_np[...] = start[name]
Expand All @@ -246,27 +270,31 @@ def __init__(self, draws:int, tune:int, step_method, chain:int, seed, start):
self._readable = True
self._num_samples = 0

self._process = _Process(
process_name,
remote_conn,
step_method,
self._shared_point,
draws,
tune,
seed,
if step_method_pickled is not None:
step_method_send = step_method_pickled
else:
step_method_send = step_method

self._process = mp_ctx.Process(
daemon=True,
name=process_name,
target=_run_process,
args=(
process_name,
remote_conn,
step_method_send,
step_method_pickled is not None,
self._shared_point,
draws,
tune,
seed,
pickle_backend,
)
)
try:
self._process.start()
except IOError as e:
# Something may have gone wrong during the fork / spawn
if e.errno == errno.EPIPE:
exc = _get_broken_pipe_exception()
if exc is not None:
# Sleep a little to give the child process time to flush
# all its error message
time.sleep(0.2)
raise exc
raise
self._process.start()
# Close the remote pipe, so that we get notified if the other
# end is closed.
remote_conn.close()

@property
def shared_point_view(self):
Expand All @@ -277,15 +305,38 @@ def shared_point_view(self):
raise RuntimeError()
return self._point

def _send(self, msg, *args):
try:
self._msg_pipe.send((msg, *args))
except Exception:
# try to recive an error message
message = None
try:
message = self._msg_pipe.recv()
except Exception:
pass
if message is not None and message[0] == "error":
warns, old_error = message[1:]
if warns is not None:
error = ParallelSamplingError(
str(old_error),
self.chain,
warns
)
else:
error = RuntimeError("Chain %s failed." % self.chain)
raise error from old_error
raise

def start(self):
self._msg_pipe.send(("start",))
self._send("start")

def write_next(self):
self._readable = False
self._msg_pipe.send(("write_next",))
self._send("write_next")

def abort(self):
self._msg_pipe.send(("abort",))
self._send("abort")

def join(self, timeout=None):
self._process.join(timeout)
Expand Down Expand Up @@ -324,7 +375,7 @@ def terminate_all(processes, patience=2):
for process in processes:
try:
process.abort()
except EOFError:
except Exception:
pass

start_time = time.time()
Expand Down Expand Up @@ -353,23 +404,52 @@ def terminate_all(processes, patience=2):
class ParallelSampler:
def __init__(
self,
draws:int,
tune:int,
chains:int,
cores:int,
seeds:list,
start_points:list,
draws: int,
tune: int,
chains: int,
cores: int,
seeds: list,
start_points: list,
step_method,
start_chain_num:int=0,
progressbar:bool=True,
start_chain_num: int = 0,
progressbar: bool = True,
mp_ctx=None,
pickle_backend: str = 'pickle',
):

if any(len(arg) != chains for arg in [seeds, start_points]):
raise ValueError("Number of seeds and start_points must be %s." % chains)

if mp_ctx is None or isinstance(mp_ctx, str):
# Closes issue https://github.com/pymc-devs/pymc3/issues/3849
if platform.system() == 'Darwin':
mp_ctx = "forkserver"
mp_ctx = multiprocessing.get_context(mp_ctx)

step_method_pickled = None
if mp_ctx.get_start_method() != 'fork':
if pickle_backend == 'pickle':
step_method_pickled = pickle.dumps(step_method, protocol=-1)
elif pickle_backend == 'dill':
try:
import dill
except ImportError:
raise ValueError(
"dill must be installed for pickle_backend='dill'."
)
step_method_pickled = dill.dumps(step_method, protocol=-1)

self._samplers = [
ProcessAdapter(
draws, tune, step_method, chain + start_chain_num, seed, start
draws,
tune,
step_method,
step_method_pickled,
chain + start_chain_num,
seed,
start,
mp_ctx,
pickle_backend
)
for chain, seed, start in zip(range(chains), seeds, start_points)
]
Expand Down
Loading