From fc1b06d2614f84744d873770b0427d535e382457 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Wed, 27 Nov 2019 14:28:10 -0800 Subject: [PATCH 1/6] Ability to start MCMC sampling from the same warmup state --- numpyro/infer/mcmc.py | 55 +++++++++++++++++++++++++++++-------------- numpyro/util.py | 20 +++++++++++----- test/test_mcmc.py | 7 ++++-- 3 files changed, 56 insertions(+), 26 deletions(-) diff --git a/numpyro/infer/mcmc.py b/numpyro/infer/mcmc.py index 021f2fac9..1ad835839 100644 --- a/numpyro/infer/mcmc.py +++ b/numpyro/infer/mcmc.py @@ -688,6 +688,7 @@ def __init__(self, self._jit_model_args = jit_model_args self._states = None self._states_flat = None + self._init_state = None self._cache = {} def _get_cached_fn(self): @@ -723,22 +724,32 @@ def _hashable(x): self._cache[key] = fn return fn - def _single_chain_mcmc(self, rng_key, init_params, args, kwargs, collect_fields=('z',), collect_warmup=False): - init_state = self.sampler.init(rng_key, self.num_warmup, init_params, - model_args=args, model_kwargs=kwargs) + def _single_chain_mcmc(self, rng_key, init_params, args, kwargs, + collect_fields=('z',), collect_warmup=False, reuse_warmup=False): + num_warmup = self.num_warmup + if reuse_warmup: + if self._init_state is None: + raise ValueError('No `init_state` found; warmup results cannot be reused.') + num_warmup = 0 + init_state = self._init_state + else: + init_state = self.sampler.init(rng_key, self.num_warmup, init_params, + model_args=args, model_kwargs=kwargs) if self.constrain_fn is None: self.constrain_fn = self.sampler.constrain_fn(args, kwargs) collect_fn = attrgetter(*collect_fields) - lower = 0 if collect_warmup else self.num_warmup + lower = 0 if collect_warmup else num_warmup diagnostics = lambda x: get_diagnostics_str(x[0]) if rng_key.ndim == 1 else None # noqa: E731 init_val = (init_state, args, kwargs) if self._jit_model_args else (init_state,) - states = fori_collect(lower, self.num_warmup + self.num_samples, - self._get_cached_fn(), - init_val, - transform=lambda x: collect_fn(x[0]), # noqa: E731 - progbar=self.progress_bar, - progbar_desc=functools.partial(get_progbar_desc_str, self.num_warmup), - diagnostics_fn=diagnostics) + collect_vals = fori_collect(lower, num_warmup + self.num_samples, + self._get_cached_fn(), + init_val, + transform=lambda x: collect_fn(x[0]), # noqa: E731 + progbar=self.progress_bar, + return_init_state=True, + progbar_desc=functools.partial(get_progbar_desc_str, self.num_warmup), + diagnostics_fn=diagnostics) + states, self._init_state = collect_vals if len(collect_fields) == 1: states = (states,) states = dict(zip(collect_fields, states)) @@ -747,15 +758,19 @@ def _single_chain_mcmc(self, rng_key, init_params, args, kwargs, collect_fields= states['z'] = lax.map(self.constrain_fn, states['z']) return states - def _single_chain_jit_args(self, init, collect_fields=('z',), collect_warmup=False): - return self._single_chain_mcmc(*init, collect_fields=collect_fields, collect_warmup=collect_warmup) + def _single_chain_jit_args(self, init, collect_fields=('z',), collect_warmup=False, reuse_warmup=False): + return self._single_chain_mcmc(*init, collect_fields=collect_fields, + collect_warmup=collect_warmup, reuse_warmup=reuse_warmup) - def _single_chain_nojit_args(self, init, model_args, model_kwargs, collect_fields=('z',), collect_warmup=False): + def _single_chain_nojit_args(self, init, model_args, model_kwargs, collect_fields=('z',), + collect_warmup=False, reuse_warmup=False): return self._single_chain_mcmc(*init, model_args, model_kwargs, collect_fields=collect_fields, - collect_warmup=collect_warmup) + collect_warmup=collect_warmup, + reuse_warmup=reuse_warmup) - def run(self, rng_key, *args, extra_fields=(), collect_warmup=False, init_params=None, **kwargs): + def run(self, rng_key, *args, extra_fields=(), collect_warmup=False, + init_params=None, reuse_warmup=False, **kwargs): """ Run the MCMC samplers and collect samples. @@ -769,6 +784,8 @@ def run(self, rng_key, *args, extra_fields=(), collect_warmup=False, init_params to `False`. :param init_params: Initial parameters to begin sampling. The type must be consistent with the input type to `potential_fn`. + :param bool reuse_warmup: If `True`, sampling would make use of the initial state and + adaptation parameters from the last warmup run. :param kwargs: Keyword arguments to be provided to the :meth:`numpyro.infer.mcmc.MCMCKernel.init` method. These are typically the keyword arguments needed by the `model`. """ @@ -798,13 +815,15 @@ def run(self, rng_key, *args, extra_fields=(), collect_warmup=False, init_params if self._jit_model_args: partial_map_fn = partial(self._single_chain_jit_args, collect_fields=collect_fields, - collect_warmup=collect_warmup) + collect_warmup=collect_warmup, + reuse_warmup=reuse_warmup) else: partial_map_fn = partial(self._single_chain_nojit_args, model_args=args, model_kwargs=kwargs, collect_fields=collect_fields, - collect_warmup=collect_warmup) + collect_warmup=collect_warmup, + reuse_warmup=reuse_warmup) if chain_method == 'sequential': if self.progress_bar: map_fn = partial(_laxmap, partial_map_fn) diff --git a/numpyro/util.py b/numpyro/util.py index 083b983cb..55e01e086 100644 --- a/numpyro/util.py +++ b/numpyro/util.py @@ -130,7 +130,8 @@ def identity(x): return x -def fori_collect(lower, upper, body_fun, init_val, transform=identity, progbar=True, **progbar_opts): +def fori_collect(lower, upper, body_fun, init_val, transform=identity, + progbar=True, return_init_state=False, **progbar_opts): """ This looping construct works like :func:`~jax.lax.fori_loop` but with the additional effect of collecting values from the loop body. In addition, this allows for @@ -164,23 +165,29 @@ def fori_collect(lower, upper, body_fun, init_val, transform=identity, progbar=T collection = np.zeros((upper - lower,) + init_val_flat.shape) def _body_fn(i, vals): - val, collection = vals + val, collection, start_state = vals val = body_fun(val) i = np.where(i >= lower, i - lower, 0) + start_state = lax.cond(i == lower-1, + start_state, lambda _: val, + start_state, lambda x: x) collection = ops.index_update(collection, i, ravel_fn(val)) - return val, collection + return val, collection, start_state - _, collection = fori_loop(0, upper, _body_fn, (init_val, collection)) + _, collection, start_state = fori_loop(0, upper, _body_fn, (init_val, collection, init_val)) else: diagnostics_fn = progbar_opts.pop('diagnostics_fn', None) progbar_desc = progbar_opts.pop('progbar_desc', lambda x: '') collection = [] val = init_val + start_state = None with tqdm.trange(upper) as t: for i in t: val = jit(body_fun)(val) - if i >= lower: + if i == lower - 1: + start_state = val + elif i >= lower: collection.append(jit(ravel_fn)(val)) t.set_description(progbar_desc(i), refresh=False) if diagnostics_fn: @@ -188,7 +195,8 @@ def _body_fn(i, vals): collection = np.stack(collection) - return vmap(unravel_fn)(collection) + unravel_collection = vmap(unravel_fn)(collection) + return unravel_collection, start_state if return_init_state else start_state def copy_docs_from(source_class, full_text=False): diff --git a/test/test_mcmc.py b/test/test_mcmc.py index bc26df448..4afc4c438 100644 --- a/test/test_mcmc.py +++ b/test/test_mcmc.py @@ -237,8 +237,11 @@ def model(data): data = dist.Normal(true_mean, true_std).sample(random.PRNGKey(1), (2000,)) kernel = NUTS(model=model) - mcmc = MCMC(kernel, num_warmup, num_samples) + # num_samples must be >= 1; otherwise we'll get an error + mcmc = MCMC(kernel, num_warmup, num_samples=1, progress_bar=False) mcmc.run(random.PRNGKey(2), data) + mcmc.num_samples = num_samples + mcmc.run(random.PRNGKey(2), data, reuse_warmup=True) samples = mcmc.get_samples() assert_allclose(np.mean(samples['mean']), true_mean, rtol=0.05) assert_allclose(np.mean(samples['std']), true_std, rtol=0.05) @@ -392,7 +395,7 @@ def model(data): return p_latent data = dist.Categorical(np.array([0.1, 0.6, 0.3])).sample(random.PRNGKey(1), (2000,)) - kernel = NUTS(model, ) + kernel = NUTS(model) mcmc = MCMC(kernel, 2, 5, num_chains=2, chain_method=chain_method, jit_model_args=compile_args) mcmc.run(random.PRNGKey(0), data) From 2f616617d02e1057bdb1632998c4fb692372edd7 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Wed, 27 Nov 2019 14:49:00 -0800 Subject: [PATCH 2/6] add to fori_collect doc --- numpyro/util.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/numpyro/util.py b/numpyro/util.py index 55e01e086..d734d0856 100644 --- a/numpyro/util.py +++ b/numpyro/util.py @@ -149,6 +149,9 @@ def fori_collect(lower, upper, body_fun, init_val, transform=identity, be any Python collection type containing `np.ndarray` objects. :param transform: a callable to post-process the values returned by `body_fn`. :param progbar: whether to post progress bar updates. + :param bool return_init_state: If `True`, the state at iteration `lower-1`, + where the collection begins, is also returned. This has the same type + as `init_val`. :param `**progbar_opts`: optional additional progress bar arguments. A `diagnostics_fn` can be supplied which when passed the current value from `body_fun` returns a string that is used to update the progress From a756daead5e0fe5f2e604f6a5424404a15024f1c Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Wed, 27 Nov 2019 16:54:34 -0800 Subject: [PATCH 3/6] address comment --- numpyro/util.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/numpyro/util.py b/numpyro/util.py index d734d0856..3fc2978a9 100644 --- a/numpyro/util.py +++ b/numpyro/util.py @@ -183,8 +183,7 @@ def _body_fn(i, vals): progbar_desc = progbar_opts.pop('progbar_desc', lambda x: '') collection = [] - val = init_val - start_state = None + val, start_state = init_val, init_val with tqdm.trange(upper) as t: for i in t: val = jit(body_fun)(val) @@ -199,7 +198,7 @@ def _body_fn(i, vals): collection = np.stack(collection) unravel_collection = vmap(unravel_fn)(collection) - return unravel_collection, start_state if return_init_state else start_state + return (unravel_collection, start_state) if return_init_state else unravel_collection def copy_docs_from(source_class, full_text=False): From 9a5e7cf98aa744bad8c7567347b36d76959d68e8 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Wed, 27 Nov 2019 17:54:17 -0800 Subject: [PATCH 4/6] address comment in fori_loop --- numpyro/infer/mcmc.py | 11 +++++++---- numpyro/util.py | 5 +++-- test/test_mcmc.py | 2 +- 3 files changed, 11 insertions(+), 7 deletions(-) diff --git a/numpyro/infer/mcmc.py b/numpyro/infer/mcmc.py index 1ad835839..25d834e39 100644 --- a/numpyro/infer/mcmc.py +++ b/numpyro/infer/mcmc.py @@ -668,7 +668,7 @@ class MCMC(object): def __init__(self, sampler, num_warmup, - num_samples, + num_samples=0, num_chains=1, constrain_fn=None, chain_method='parallel', @@ -749,12 +749,14 @@ def _single_chain_mcmc(self, rng_key, init_params, args, kwargs, return_init_state=True, progbar_desc=functools.partial(get_progbar_desc_str, self.num_warmup), diagnostics_fn=diagnostics) - states, self._init_state = collect_vals + states, init_state = collect_vals + self._init_state = init_state[0] if len(collect_fields) == 1: states = (states,) states = dict(zip(collect_fields, states)) # Apply constraints if number of samples is non-zero - if len(tree_flatten(states['z'])[0]) > 0: + site_values = tree_flatten(states['z'])[0] + if len(site_values) > 0 and site_values[0].size > 0: states['z'] = lax.map(self.constrain_fn, states['z']) return states @@ -808,7 +810,8 @@ def run(self, rng_key, *args, extra_fields=(), collect_warmup=False, assert isinstance(extra_fields, (tuple, list)) collect_fields = tuple(set(('z', 'diverging') + tuple(extra_fields))) if self.num_chains == 1: - states_flat = self._single_chain_mcmc(rng_key, init_params, args, kwargs, collect_fields, collect_warmup) + states_flat = self._single_chain_mcmc(rng_key, init_params, args, kwargs, collect_fields, + collect_warmup, reuse_warmup) states = tree_map(lambda x: x[np.newaxis, ...], states_flat) else: rng_keys = random.split(rng_key, self.num_chains) diff --git a/numpyro/util.py b/numpyro/util.py index 3fc2978a9..81b5afa69 100644 --- a/numpyro/util.py +++ b/numpyro/util.py @@ -160,7 +160,7 @@ def fori_collect(lower, upper, body_fun, init_val, transform=identity, :return: collection with the same type as `init_val` with values collected along the leading axis of `np.ndarray` objects. """ - assert lower < upper + assert lower <= upper init_val_flat, unravel_fn = ravel_pytree(transform(init_val)) ravel_fn = lambda x: ravel_pytree(transform(x))[0] # noqa: E731 @@ -195,7 +195,8 @@ def _body_fn(i, vals): if diagnostics_fn: t.set_postfix_str(diagnostics_fn(val), refresh=False) - collection = np.stack(collection) + collection = np.stack(collection) if len(collection) > 0 else \ + np.zeros((upper - lower,) + init_val_flat.shape) unravel_collection = vmap(unravel_fn)(collection) return (unravel_collection, start_state) if return_init_state else unravel_collection diff --git a/test/test_mcmc.py b/test/test_mcmc.py index 4afc4c438..b04c7c8aa 100644 --- a/test/test_mcmc.py +++ b/test/test_mcmc.py @@ -238,7 +238,7 @@ def model(data): data = dist.Normal(true_mean, true_std).sample(random.PRNGKey(1), (2000,)) kernel = NUTS(model=model) # num_samples must be >= 1; otherwise we'll get an error - mcmc = MCMC(kernel, num_warmup, num_samples=1, progress_bar=False) + mcmc = MCMC(kernel, num_warmup) mcmc.run(random.PRNGKey(2), data) mcmc.num_samples = num_samples mcmc.run(random.PRNGKey(2), data, reuse_warmup=True) From 3ec315b5bde861a97f3cefb78ccc0e4dd6d40886 Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Wed, 27 Nov 2019 19:08:35 -0800 Subject: [PATCH 5/6] address comments --- numpyro/infer/mcmc.py | 37 +++++++++++++++++++------------------ test/test_mcmc.py | 3 ++- 2 files changed, 21 insertions(+), 19 deletions(-) diff --git a/numpyro/infer/mcmc.py b/numpyro/infer/mcmc.py index 25d834e39..4d8f7eaea 100644 --- a/numpyro/infer/mcmc.py +++ b/numpyro/infer/mcmc.py @@ -688,7 +688,8 @@ def __init__(self, self._jit_model_args = jit_model_args self._states = None self._states_flat = None - self._init_state = None + # HMCState returned by last warmup + self._warmup_state = None self._cache = {} def _get_cached_fn(self): @@ -725,13 +726,13 @@ def _hashable(x): return fn def _single_chain_mcmc(self, rng_key, init_params, args, kwargs, - collect_fields=('z',), collect_warmup=False, reuse_warmup=False): + collect_fields=('z',), collect_warmup=False, reuse_warmup_state=False): num_warmup = self.num_warmup - if reuse_warmup: - if self._init_state is None: + if reuse_warmup_state: + if self._warmup_state is None: raise ValueError('No `init_state` found; warmup results cannot be reused.') num_warmup = 0 - init_state = self._init_state + init_state = self._warmup_state else: init_state = self.sampler.init(rng_key, self.num_warmup, init_params, model_args=args, model_kwargs=kwargs) @@ -747,10 +748,10 @@ def _single_chain_mcmc(self, rng_key, init_params, args, kwargs, transform=lambda x: collect_fn(x[0]), # noqa: E731 progbar=self.progress_bar, return_init_state=True, - progbar_desc=functools.partial(get_progbar_desc_str, self.num_warmup), + progbar_desc=functools.partial(get_progbar_desc_str, num_warmup), diagnostics_fn=diagnostics) - states, init_state = collect_vals - self._init_state = init_state[0] + states, warmup_state = collect_vals + self._warmup_state = warmup_state[0]._replace(i=0) if len(collect_fields) == 1: states = (states,) states = dict(zip(collect_fields, states)) @@ -760,19 +761,19 @@ def _single_chain_mcmc(self, rng_key, init_params, args, kwargs, states['z'] = lax.map(self.constrain_fn, states['z']) return states - def _single_chain_jit_args(self, init, collect_fields=('z',), collect_warmup=False, reuse_warmup=False): + def _single_chain_jit_args(self, init, collect_fields=('z',), collect_warmup=False, reuse_warmup_state=False): return self._single_chain_mcmc(*init, collect_fields=collect_fields, - collect_warmup=collect_warmup, reuse_warmup=reuse_warmup) + collect_warmup=collect_warmup, reuse_warmup_state=reuse_warmup_state) def _single_chain_nojit_args(self, init, model_args, model_kwargs, collect_fields=('z',), - collect_warmup=False, reuse_warmup=False): + collect_warmup=False, reuse_warmup_state=False): return self._single_chain_mcmc(*init, model_args, model_kwargs, collect_fields=collect_fields, collect_warmup=collect_warmup, - reuse_warmup=reuse_warmup) + reuse_warmup_state=reuse_warmup_state) def run(self, rng_key, *args, extra_fields=(), collect_warmup=False, - init_params=None, reuse_warmup=False, **kwargs): + init_params=None, reuse_warmup_state=False, **kwargs): """ Run the MCMC samplers and collect samples. @@ -786,8 +787,8 @@ def run(self, rng_key, *args, extra_fields=(), collect_warmup=False, to `False`. :param init_params: Initial parameters to begin sampling. The type must be consistent with the input type to `potential_fn`. - :param bool reuse_warmup: If `True`, sampling would make use of the initial state and - adaptation parameters from the last warmup run. + :param bool reuse_warmup_state: If `True`, sampling would make use of the last state and + adaptation parameters from the previous warmup run. :param kwargs: Keyword arguments to be provided to the :meth:`numpyro.infer.mcmc.MCMCKernel.init` method. These are typically the keyword arguments needed by the `model`. """ @@ -811,7 +812,7 @@ def run(self, rng_key, *args, extra_fields=(), collect_warmup=False, collect_fields = tuple(set(('z', 'diverging') + tuple(extra_fields))) if self.num_chains == 1: states_flat = self._single_chain_mcmc(rng_key, init_params, args, kwargs, collect_fields, - collect_warmup, reuse_warmup) + collect_warmup, reuse_warmup_state) states = tree_map(lambda x: x[np.newaxis, ...], states_flat) else: rng_keys = random.split(rng_key, self.num_chains) @@ -819,14 +820,14 @@ def run(self, rng_key, *args, extra_fields=(), collect_warmup=False, partial_map_fn = partial(self._single_chain_jit_args, collect_fields=collect_fields, collect_warmup=collect_warmup, - reuse_warmup=reuse_warmup) + reuse_warmup_state=reuse_warmup_state) else: partial_map_fn = partial(self._single_chain_nojit_args, model_args=args, model_kwargs=kwargs, collect_fields=collect_fields, collect_warmup=collect_warmup, - reuse_warmup=reuse_warmup) + reuse_warmup_state=reuse_warmup_state) if chain_method == 'sequential': if self.progress_bar: map_fn = partial(_laxmap, partial_map_fn) diff --git a/test/test_mcmc.py b/test/test_mcmc.py index b04c7c8aa..1a1e17fb4 100644 --- a/test/test_mcmc.py +++ b/test/test_mcmc.py @@ -241,7 +241,7 @@ def model(data): mcmc = MCMC(kernel, num_warmup) mcmc.run(random.PRNGKey(2), data) mcmc.num_samples = num_samples - mcmc.run(random.PRNGKey(2), data, reuse_warmup=True) + mcmc.run(random.PRNGKey(2), data, reuse_warmup_state=True) samples = mcmc.get_samples() assert_allclose(np.mean(samples['mean']), true_mean, rtol=0.05) assert_allclose(np.mean(samples['std']), true_std, rtol=0.05) @@ -398,6 +398,7 @@ def model(data): kernel = NUTS(model) mcmc = MCMC(kernel, 2, 5, num_chains=2, chain_method=chain_method, jit_model_args=compile_args) mcmc.run(random.PRNGKey(0), data) + mcmc.run(random.PRNGKey(1), data, reuse_warmup_state=True) def test_extra_fields(): From 6be4130115d5dd6d33681288b595e2713ced074a Mon Sep 17 00:00:00 2001 From: Neeraj Pradhan Date: Wed, 27 Nov 2019 19:25:52 -0800 Subject: [PATCH 6/6] reset state.i to 0 to avoid confusion --- numpyro/infer/mcmc.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/numpyro/infer/mcmc.py b/numpyro/infer/mcmc.py index 4d8f7eaea..d66d81ad2 100644 --- a/numpyro/infer/mcmc.py +++ b/numpyro/infer/mcmc.py @@ -751,7 +751,11 @@ def _single_chain_mcmc(self, rng_key, init_params, args, kwargs, progbar_desc=functools.partial(get_progbar_desc_str, num_warmup), diagnostics_fn=diagnostics) states, warmup_state = collect_vals - self._warmup_state = warmup_state[0]._replace(i=0) + # Get first argument of type `HMCState` + warmup_state = warmup_state[0] + # Note that setting i = 0 may result in recompilation due to python integers having + # weak type + self._warmup_state = warmup_state._replace(i=np.zeros_like(warmup_state.i)) if len(collect_fields) == 1: states = (states,) states = dict(zip(collect_fields, states))