Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Ability to start MCMC sampling from the same warmup state #469

Merged
merged 6 commits into from
Nov 28, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
69 changes: 48 additions & 21 deletions numpyro/infer/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -668,7 +668,7 @@ class MCMC(object):
def __init__(self,
sampler,
num_warmup,
num_samples,
num_samples=0,
num_chains=1,
constrain_fn=None,
chain_method='parallel',
Expand All @@ -688,6 +688,8 @@ def __init__(self,
self._jit_model_args = jit_model_args
self._states = None
self._states_flat = None
# HMCState returned by last warmup
self._warmup_state = None
self._cache = {}

def _get_cached_fn(self):
Expand Down Expand Up @@ -723,39 +725,59 @@ def _hashable(x):
self._cache[key] = fn
return fn

def _single_chain_mcmc(self, rng_key, init_params, args, kwargs, collect_fields=('z',), collect_warmup=False):
init_state = self.sampler.init(rng_key, self.num_warmup, init_params,
model_args=args, model_kwargs=kwargs)
def _single_chain_mcmc(self, rng_key, init_params, args, kwargs,
collect_fields=('z',), collect_warmup=False, reuse_warmup_state=False):
num_warmup = self.num_warmup
if reuse_warmup_state:
if self._warmup_state is None:
raise ValueError('No `init_state` found; warmup results cannot be reused.')
num_warmup = 0
init_state = self._warmup_state
else:
init_state = self.sampler.init(rng_key, self.num_warmup, init_params,
model_args=args, model_kwargs=kwargs)
if self.constrain_fn is None:
self.constrain_fn = self.sampler.constrain_fn(args, kwargs)
collect_fn = attrgetter(*collect_fields)
lower = 0 if collect_warmup else self.num_warmup
lower = 0 if collect_warmup else num_warmup
diagnostics = lambda x: get_diagnostics_str(x[0]) if rng_key.ndim == 1 else None # noqa: E731
init_val = (init_state, args, kwargs) if self._jit_model_args else (init_state,)
states = fori_collect(lower, self.num_warmup + self.num_samples,
self._get_cached_fn(),
init_val,
transform=lambda x: collect_fn(x[0]), # noqa: E731
progbar=self.progress_bar,
progbar_desc=functools.partial(get_progbar_desc_str, self.num_warmup),
diagnostics_fn=diagnostics)
collect_vals = fori_collect(lower, num_warmup + self.num_samples,
self._get_cached_fn(),
init_val,
transform=lambda x: collect_fn(x[0]), # noqa: E731
progbar=self.progress_bar,
return_init_state=True,
progbar_desc=functools.partial(get_progbar_desc_str, num_warmup),
diagnostics_fn=diagnostics)
states, warmup_state = collect_vals
# Get first argument of type `HMCState`
warmup_state = warmup_state[0]
# Note that setting i = 0 may result in recompilation due to python integers having
# weak type
self._warmup_state = warmup_state._replace(i=np.zeros_like(warmup_state.i))
if len(collect_fields) == 1:
states = (states,)
states = dict(zip(collect_fields, states))
# Apply constraints if number of samples is non-zero
if len(tree_flatten(states['z'])[0]) > 0:
site_values = tree_flatten(states['z'])[0]
if len(site_values) > 0 and site_values[0].size > 0:
states['z'] = lax.map(self.constrain_fn, states['z'])
return states

def _single_chain_jit_args(self, init, collect_fields=('z',), collect_warmup=False):
return self._single_chain_mcmc(*init, collect_fields=collect_fields, collect_warmup=collect_warmup)
def _single_chain_jit_args(self, init, collect_fields=('z',), collect_warmup=False, reuse_warmup_state=False):
return self._single_chain_mcmc(*init, collect_fields=collect_fields,
collect_warmup=collect_warmup, reuse_warmup_state=reuse_warmup_state)

def _single_chain_nojit_args(self, init, model_args, model_kwargs, collect_fields=('z',), collect_warmup=False):
def _single_chain_nojit_args(self, init, model_args, model_kwargs, collect_fields=('z',),
collect_warmup=False, reuse_warmup_state=False):
return self._single_chain_mcmc(*init, model_args, model_kwargs,
collect_fields=collect_fields,
collect_warmup=collect_warmup)
collect_warmup=collect_warmup,
reuse_warmup_state=reuse_warmup_state)

def run(self, rng_key, *args, extra_fields=(), collect_warmup=False, init_params=None, **kwargs):
def run(self, rng_key, *args, extra_fields=(), collect_warmup=False,
init_params=None, reuse_warmup_state=False, **kwargs):
"""
Run the MCMC samplers and collect samples.

Expand All @@ -769,6 +791,8 @@ def run(self, rng_key, *args, extra_fields=(), collect_warmup=False, init_params
to `False`.
:param init_params: Initial parameters to begin sampling. The type must be consistent
with the input type to `potential_fn`.
:param bool reuse_warmup_state: If `True`, sampling would make use of the last state and
adaptation parameters from the previous warmup run.
:param kwargs: Keyword arguments to be provided to the :meth:`numpyro.infer.mcmc.MCMCKernel.init`
method. These are typically the keyword arguments needed by the `model`.
"""
Expand All @@ -791,20 +815,23 @@ def run(self, rng_key, *args, extra_fields=(), collect_warmup=False, init_params
assert isinstance(extra_fields, (tuple, list))
collect_fields = tuple(set(('z', 'diverging') + tuple(extra_fields)))
if self.num_chains == 1:
states_flat = self._single_chain_mcmc(rng_key, init_params, args, kwargs, collect_fields, collect_warmup)
states_flat = self._single_chain_mcmc(rng_key, init_params, args, kwargs, collect_fields,
collect_warmup, reuse_warmup_state)
states = tree_map(lambda x: x[np.newaxis, ...], states_flat)
else:
rng_keys = random.split(rng_key, self.num_chains)
if self._jit_model_args:
partial_map_fn = partial(self._single_chain_jit_args,
collect_fields=collect_fields,
collect_warmup=collect_warmup)
collect_warmup=collect_warmup,
reuse_warmup_state=reuse_warmup_state)
else:
partial_map_fn = partial(self._single_chain_nojit_args,
model_args=args,
model_kwargs=kwargs,
collect_fields=collect_fields,
collect_warmup=collect_warmup)
collect_warmup=collect_warmup,
reuse_warmup_state=reuse_warmup_state)
if chain_method == 'sequential':
if self.progress_bar:
map_fn = partial(_laxmap, partial_map_fn)
Expand Down
29 changes: 20 additions & 9 deletions numpyro/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,8 @@ def identity(x):
return x


def fori_collect(lower, upper, body_fun, init_val, transform=identity, progbar=True, **progbar_opts):
def fori_collect(lower, upper, body_fun, init_val, transform=identity,
progbar=True, return_init_state=False, **progbar_opts):
"""
This looping construct works like :func:`~jax.lax.fori_loop` but with the additional
effect of collecting values from the loop body. In addition, this allows for
Expand All @@ -148,6 +149,9 @@ def fori_collect(lower, upper, body_fun, init_val, transform=identity, progbar=T
be any Python collection type containing `np.ndarray` objects.
:param transform: a callable to post-process the values returned by `body_fn`.
:param progbar: whether to post progress bar updates.
:param bool return_init_state: If `True`, the state at iteration `lower-1`,
where the collection begins, is also returned. This has the same type
as `init_val`.
:param `**progbar_opts`: optional additional progress bar arguments. A
`diagnostics_fn` can be supplied which when passed the current value
from `body_fun` returns a string that is used to update the progress
Expand All @@ -156,39 +160,46 @@ def fori_collect(lower, upper, body_fun, init_val, transform=identity, progbar=T
:return: collection with the same type as `init_val` with values
collected along the leading axis of `np.ndarray` objects.
"""
assert lower < upper
assert lower <= upper
init_val_flat, unravel_fn = ravel_pytree(transform(init_val))
ravel_fn = lambda x: ravel_pytree(transform(x))[0] # noqa: E731

if not progbar:
collection = np.zeros((upper - lower,) + init_val_flat.shape)

def _body_fn(i, vals):
val, collection = vals
val, collection, start_state = vals
val = body_fun(val)
i = np.where(i >= lower, i - lower, 0)
start_state = lax.cond(i == lower-1,
start_state, lambda _: val,
start_state, lambda x: x)
collection = ops.index_update(collection, i, ravel_fn(val))
return val, collection
return val, collection, start_state

_, collection = fori_loop(0, upper, _body_fn, (init_val, collection))
_, collection, start_state = fori_loop(0, upper, _body_fn, (init_val, collection, init_val))
else:
diagnostics_fn = progbar_opts.pop('diagnostics_fn', None)
progbar_desc = progbar_opts.pop('progbar_desc', lambda x: '')
collection = []

val = init_val
val, start_state = init_val, init_val
with tqdm.trange(upper) as t:
for i in t:
val = jit(body_fun)(val)
if i >= lower:
if i == lower - 1:
start_state = val
elif i >= lower:
collection.append(jit(ravel_fn)(val))
t.set_description(progbar_desc(i), refresh=False)
if diagnostics_fn:
t.set_postfix_str(diagnostics_fn(val), refresh=False)

collection = np.stack(collection)
collection = np.stack(collection) if len(collection) > 0 else \
np.zeros((upper - lower,) + init_val_flat.shape)

return vmap(unravel_fn)(collection)
unravel_collection = vmap(unravel_fn)(collection)
return (unravel_collection, start_state) if return_init_state else unravel_collection


def copy_docs_from(source_class, full_text=False):
Expand Down
8 changes: 6 additions & 2 deletions test/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,8 +237,11 @@ def model(data):

data = dist.Normal(true_mean, true_std).sample(random.PRNGKey(1), (2000,))
kernel = NUTS(model=model)
mcmc = MCMC(kernel, num_warmup, num_samples)
# num_samples must be >= 1; otherwise we'll get an error
mcmc = MCMC(kernel, num_warmup)
mcmc.run(random.PRNGKey(2), data)
mcmc.num_samples = num_samples
mcmc.run(random.PRNGKey(2), data, reuse_warmup_state=True)
samples = mcmc.get_samples()
assert_allclose(np.mean(samples['mean']), true_mean, rtol=0.05)
assert_allclose(np.mean(samples['std']), true_std, rtol=0.05)
Expand Down Expand Up @@ -392,9 +395,10 @@ def model(data):
return p_latent

data = dist.Categorical(np.array([0.1, 0.6, 0.3])).sample(random.PRNGKey(1), (2000,))
kernel = NUTS(model, )
kernel = NUTS(model)
mcmc = MCMC(kernel, 2, 5, num_chains=2, chain_method=chain_method, jit_model_args=compile_args)
mcmc.run(random.PRNGKey(0), data)
mcmc.run(random.PRNGKey(1), data, reuse_warmup_state=True)


def test_extra_fields():
Expand Down