Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

SMC: refactor, speed-up and run multiple chains in parallel for diagnostics #3981

Merged
merged 13 commits into from
Jun 29, 2020
173 changes: 112 additions & 61 deletions docs/source/notebooks/Bayes_factor.ipynb

Large diffs are not rendered by default.

355 changes: 282 additions & 73 deletions docs/source/notebooks/SMC2_gaussians.ipynb

Large diffs are not rendered by default.

160 changes: 124 additions & 36 deletions pymc3/smc/sample_smc.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,25 +14,37 @@

import time
import logging
import warnings
from collections.abc import Iterable
import multiprocessing as mp
import numpy as np

from .smc import SMC
from ..model import modelcontext
from ..backends.base import MultiTrace
from ..parallel_sampling import _cpu_count

EXPERIMENTAL_WARNING = (
"Warning: SMC-ABC is an experimental step method and not yet recommended for use in PyMC3!"
)


def sample_smc(
draws=1000,
draws=2000,
kernel="metropolis",
n_steps=25,
parallel=False,
start=None,
cores=None,
tune_steps=True,
p_acc_rate=0.99,
threshold=0.5,
epsilon=1.0,
dist_func="gaussian_kernel",
sum_stat="identity",
progressbar=False,
model=None,
random_seed=-1,
parallel=False,
chains=None,
cores=None,
):
r"""
Sequential Monte Carlo based sampling
Expand All @@ -49,15 +61,9 @@ def sample_smc(
The number of steps of each Markov Chain. If ``tune_steps == True`` ``n_steps`` will be used
for the first stage and for the others it will be determined automatically based on the
acceptance rate and `p_acc_rate`, the max number of steps is ``n_steps``.
parallel: bool
Distribute computations across cores if the number of cores is larger than 1.
Defaults to False.
start: dict, or array of dict
Starting point in parameter space. It should be a list of dict with length `chains`.
When None (default) the starting point is sampled from the prior distribution.
cores: int
The number of chains to run in parallel. If ``None`` (default), it will be automatically
set to the number of CPUs in the system.
tune_steps: bool
Whether to compute the number of steps automatically or not. Defaults to True
p_acc_rate: float
Expand All @@ -75,11 +81,19 @@ def sample_smc(
sum_stat: str or callable
Summary statistics. Available options are ``indentity``, ``sorted``, ``mean``, ``median``.
If a callable is based it should return a number or a 1d numpy array.
progressbar: bool
Flag for displaying a progress bar. Defaults to False.
model: Model (optional if in ``with`` context)).
random_seed: int
random seed
parallel: bool
Distribute computations across cores if the number of cores is larger than 1.
Defaults to False.
cores : int
The number of chains to run in parallel. If ``None``, set to the number of CPUs in the
system, but at most 4.
chains : int
The number of chains to sample. Running independent chains is important for some
convergence statistics. If ``None`` (default), then set to either ``cores`` or 2, whichever
is larger.

Notes
-----
Expand Down Expand Up @@ -126,52 +140,126 @@ def sample_smc(
%282007%29133:7%28816%29>`__
"""

_log = logging.getLogger("pymc3")
_log.info("Initializing SMC sampler...")

if cores is None:
cores = _cpu_count()

if chains is None:
chains = max(2, cores)

_log.info(f"Multiprocess sampling ({chains} chains in {cores} jobs)")

if random_seed == -1:
random_seed = None
if chains == 1 and isinstance(random_seed, int):
random_seed = [random_seed]
if random_seed is None or isinstance(random_seed, int):
if random_seed is not None:
np.random.seed(random_seed)
random_seed = [np.random.randint(2 ** 30) for _ in range(chains)]
if not isinstance(random_seed, Iterable):
raise TypeError("Invalid value for `random_seed`. Must be tuple, list or int")

if kernel.lower() == "abc":
warnings.warn(EXPERIMENTAL_WARNING)
if len(modelcontext(model).observed_RVs) != 1:
warnings.warn("SMC-ABC only works properly with models with one observed variable")

params = (
draws,
kernel,
n_steps,
start,
tune_steps,
p_acc_rate,
threshold,
epsilon,
dist_func,
sum_stat,
model,
)

t1 = time.time()
if parallel:
loggers = [_log] + [None] * (chains - 1)
pool = mp.Pool(cores)
results = pool.starmap(
sample_smc_int, [(*params, random_seed[i], i, loggers[i]) for i in range(chains)]
)

pool.close()
pool.join()
else:
results = []
for i in range(chains):
results.append((sample_smc_int(*params, random_seed[i], i, _log)))

traces, log_marginal_likelihoods, betas, accept_ratios, nsteps = zip(*results)
trace = MultiTrace(traces)
trace.report._n_draws = draws
trace.report._n_tune = 0
trace.report._t_sampling = time.time() - t1
trace.report.log_marginal_likelihood = np.array(log_marginal_likelihoods)
trace.report.betas = betas
trace.report.accept_ratios = accept_ratios
trace.report.nsteps = nsteps

return trace


def sample_smc_int(
draws,
kernel,
n_steps,
start,
tune_steps,
p_acc_rate,
threshold,
epsilon,
dist_func,
sum_stat,
model,
random_seed,
chain,
_log,
):

smc = SMC(
draws=draws,
kernel=kernel,
n_steps=n_steps,
parallel=parallel,
start=start,
cores=cores,
tune_steps=tune_steps,
p_acc_rate=p_acc_rate,
threshold=threshold,
epsilon=epsilon,
dist_func=dist_func,
sum_stat=sum_stat,
progressbar=progressbar,
model=model,
random_seed=random_seed,
chain=chain,
)

t1 = time.time()
_log = logging.getLogger("pymc3")
_log.info("Sample initial stage: ...")
stage = 0
betas = []
accept_ratios = []
nsteps = []
smc.initialize_population()
smc.setup_kernel()
smc.initialize_logp()

while smc.beta < 1:
smc.update_weights_beta()
_log.info(
"Stage: {:3d} Beta: {:.3f} Steps: {:3d} Acce: {:.3f}".format(
stage, smc.beta, smc.n_steps, smc.acc_rate
)
)
smc.resample()
if _log is not None:
_log.info(f"Stage: {stage:3d} Beta: {smc.beta:.3f}")
smc.update_proposal()
if stage > 0:
smc.tune()
smc.resample()
smc.mutate()
smc.tune()
stage += 1
betas.append(smc.beta)
accept_ratios.append(smc.acc_rate)
nsteps.append(smc.n_steps)

if smc.parallel and smc.cores > 1:
smc.pool.close()
smc.pool.join()

trace = smc.posterior_to_trace()
trace.report._n_draws = smc.draws
trace.report._n_tune = 0
trace.report._t_sampling = time.time() - t1
return trace
return smc.posterior_to_trace(), smc.log_marginal_likelihood, betas, accept_ratios, nsteps
Loading