Ability to start MCMC sampling from the same warmup state #469

Merged
merged 6 commits into from
Nov 28, 2019


neerajprad
Member

Addresses #467.

This introduces a return_init_state argument to fori_collect which, if True, returns the state at index lower-1 (in MCMC, the last warmup state). This is used by the reuse_warmup argument to MCMC.run as follows:

# run for 1 sample; num_samples=0 will throw an error
mcmc = MCMC(num_warmup, num_samples=1)
mcmc.run(rng=..., *args, **kwargs)
# change number of samples
mcmc.num_samples = 1000
# runs from the last warmup state avoiding compilation
mcmc.run(rng=..., *args, reuse_warmup=True, **kwargs)

This pattern can probably be improved; in particular, num_samples=1 seems a bit odd. Ideally, we should make num_samples=0 the default, in which case running MCMC without a num_samples argument would simply compile the model and store the initial state from warmup. This however requires much more refactoring, whose value is questionable at the moment. If this turns out to be a useful pattern, we can make this prettier.
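
To make the role of return_init_state concrete, here is a minimal pure-Python sketch of a collect loop. It is not NumPyro's fori_collect: toy_fori_collect and step are hypothetical stand-ins, and the only behavior borrowed from the description above is that the flag exposes the state reached just before collection starts (the last warmup state), so a later run can resume from it.

```python
def toy_fori_collect(lower, upper, body_fun, init_val, return_init_state=False):
    """Apply body_fun `upper` times and keep the states from iteration `lower` on.

    Hypothetical stand-in for illustration only, not NumPyro's fori_collect.
    """
    state = init_val
    # Warmup phase: advance the state without collecting anything.
    for _ in range(lower):
        state = body_fun(state)
    # State just before collection begins (the "last warmup state").
    start_state = state
    collection = []
    # Sampling phase: collect every state that is produced.
    for _ in range(lower, upper):
        state = body_fun(state)
        collection.append(state)
    if return_init_state:
        return collection, start_state
    return collection


def step(x):
    # Toy deterministic transition standing in for one MCMC step.
    return x + 1


samples, warm_state = toy_fori_collect(3, 5, step, 0, return_init_state=True)
print(samples)     # [4, 5]
print(warm_state)  # 3

# Resume from the stored warmup state with lower=0: no warmup is repeated.
more = toy_fori_collect(0, 4, step, warm_state)
print(more)        # [4, 5, 6, 7]
```

In this toy version lower == upper (zero samples) works without special handling, because a plain Python list can be empty; the real implementation stacks arrays, which is where the empty case needs extra care.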

@neerajprad
Member Author

@fehiepsi - let me know if you have other suggestions on how to achieve this.

Member

@fehiepsi fehiepsi left a comment

This however requires much more refactoring

I think it is valuable and I prefer doing so. With that, we don't need to have reuse_warmup, so the way to achieve a similar result will be

mcmc = MCMC(NUTS(model), num_warmup, 0)
mcmc.run()
mcmc.num_samples = 1000
mcmc.run()

In addition, I would prefer to name it current_state rather than init_state.

@neerajprad
Member Author

neerajprad commented Nov 28, 2019

With that, we don't need to have reuse_warmup, so the way to achieve a similar result will be

I think you'll need to also set mcmc.num_warmup = 0 in that example? So what you are saying is that internally if num_warmup=0 and we have some init_state, we will use that instead?

I tried doing this, but this adds unneeded complexity to fori_collect since many collection primitives assume that there's at least 1 item in the collection. I haven't tried too hard though, and if it can be done without too many conditional checks in the code, it will be nice. If not, it's unclear what we are gaining. I'm happy to open an issue and address this later (we can also gate this reuse_warmup behind an experimental flag).

@fehiepsi
Member

@neerajprad To bypass the restriction on num_samples=0, I think you can change the collection line in fori_collect to

collection = np.stack(collection) if len(collection) > 0 else np.zeros((upper - lower,) + init_val_flat.shape)

and in MCMC._single_chain_mcmc

        # Apply constraints if number of samples is non-zero
        site_values = tree_flatten(states['z'])[0]
        if len(site_values) > 0 and site_values[0].size > 0:
            states['z'] = lax.map(self.constrain_fn, states['z'])

I tested that this works with the funnel example.

@neerajprad
Member Author

Thanks for your suggestion, @fehiepsi. I'm also noticing another issue. I'll address these, and ping you again for a review. 😄

@fehiepsi
Member

Re setting num_warmup=0: I believe we can use init_state.i as a marker for lower/upper logic of fori_collect. Let me think a bit more about it.

@neerajprad
Member Author

neerajprad commented Nov 28, 2019

Re setting num_warmup=0: I believe we can use init_state.i as a marker for lower/upper logic of fori_collect. Let me think a bit more about it.

This isn't too important though; I think num_samples = 0 makes sense. Beyond that whether we have to set num_warmup=0 or use reuse_warmup doesn't really matter. I have a slight preference for the latter, so that things are explicit and there is no magical behavior for num_warmup=0 (e.g. if someone is expecting hmc.init to be triggered even with num_warmup=0, that won't happen with this design).

@fehiepsi
Member

fehiepsi commented Nov 28, 2019

I meant we don't need to set num_warmup=0. init_state/current_state already holds the current step. I think we can make the internals of MCMC behave as follows:

mcmc = MCMC(NUTS(model), num_warmup, 0)
mcmc.run()
assert mcmc._init_state.i == num_warmup + 0
mcmc.num_samples = 1000
mcmc.run()
assert mcmc._init_state.i == num_warmup + num_samples

To achieve that, we can change the lower/upper of fori_collect in MCMC to be (modulo multi-chains)

if collect_warmup or (self._init_state.i >= self.num_warmup):
    lower = 0
else:
    lower = self.num_warmup - self._init_state.i
upper = self.num_warmup + self.num_samples - self._init_state.i

This way, we can set num_samples = 1000; then later we can set num_samples = 10000 to run the remaining 9000 steps. (Of course, after collecting the new 9000 states, we concatenate them with the current 1000 samples instead of replacing them.)
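
The lower/upper arithmetic proposed above can be checked with a small worked example. This is only a sketch: collection_bounds is a hypothetical helper (not part of NumPyro), current_step plays the role of self._init_state.i, and num_warmup=500 is an assumed value chosen for illustration.

```python
def collection_bounds(current_step, num_warmup, num_samples, collect_warmup=False):
    """Return (lower, upper) measured from the current step.

    Hypothetical helper mirroring the pseudocode in the comment above.
    """
    if collect_warmup or current_step >= num_warmup:
        # Warmup is already done (or is being collected): collect from now on.
        lower = 0
    else:
        # Skip the warmup steps that are still outstanding.
        lower = num_warmup - current_step
    # Total number of steps still to run.
    upper = num_warmup + num_samples - current_step
    return lower, upper


# Fresh run with num_samples=0: warm up for 500 steps, collect nothing.
print(collection_bounds(0, 500, 0))          # (500, 500)
# After warmup, with num_samples raised to 1000: collect 1000 steps.
print(collection_bounds(500, 500, 1000))     # (0, 1000)
# After 1000 samples, with num_samples raised to 10000: 9000 steps remain.
print(collection_bounds(1500, 500, 10000))   # (0, 9000)
```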

@neerajprad
Member Author

Okay, let me refactor based on your suggestions first, and post an update.

@neerajprad
Member Author

neerajprad commented Nov 28, 2019

@fehiepsi - Regarding your last point, I think we still need the reuse_warmup flag, otherwise we will wrongly use the _init_state from last time, even if the user wants to do a fresh warmup run (e.g. on new data). As an example, if we use the last warmup state, test_reuse_mcmc_run will fail, so this shouldn't be the default behavior (it should only get triggered by a flag).
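
The reason for gating the behavior behind a flag can be illustrated with a toy sampler. Everything here is an assumption made for the sketch: ToyMCMC, its random-walk step, and the warmup_runs counter are hypothetical and do not reflect NumPyro's MCMC internals. The point is only that a plain run() redoes warmup, while run(..., reuse_warmup=True) starts from the stored last warmup state, which sampling never overwrites.

```python
import random


class ToyMCMC:
    """Toy random-walk sampler; a hypothetical stand-in, not NumPyro's MCMC."""

    def __init__(self, num_warmup, num_samples):
        self.num_warmup = num_warmup
        self.num_samples = num_samples
        self._warmup_state = None  # position after the last warmup step
        self.warmup_runs = 0       # how many times warmup actually ran

    @staticmethod
    def _step(x, rng):
        # One random-walk move standing in for an MCMC transition.
        return x + rng.uniform(-1.0, 1.0)

    def run(self, seed, reuse_warmup=False):
        rng = random.Random(seed)
        if reuse_warmup and self._warmup_state is not None:
            # Explicit opt-in: start sampling from the stored warmup state.
            x = self._warmup_state
        else:
            # Default: fresh warmup, e.g. because the data or seed changed.
            x = 0.0
            for _ in range(self.num_warmup):
                x = self._step(x, rng)
            self._warmup_state = x
            self.warmup_runs += 1
        samples = []
        for _ in range(self.num_samples):
            x = self._step(x, rng)
            samples.append(x)
        return samples


mcmc = ToyMCMC(num_warmup=100, num_samples=1)
mcmc.run(seed=0)
stored = mcmc._warmup_state

mcmc.num_samples = 1000
draws = mcmc.run(seed=1, reuse_warmup=True)
print(len(draws), mcmc.warmup_runs)  # 1000 1

mcmc.run(seed=2)                     # no flag: warmup runs again
print(mcmc.warmup_runs)              # 2
```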

@fehiepsi
Member

Yeah, that's tricky. With that, run has two meanings: "keep running" and "reset and run again"... :D

May I suggest another name for the flag: use_current_state?

@neerajprad
Member Author

neerajprad commented Nov 28, 2019

With that, run has two meanings: "keep running" and "reset and run again"... :D

I suppose .run can be used in many ways, where the underlying motivation is that you want to share something with the previous run - e.g. share compiled code but run on fresh data, or share everything but run with a fresh seed, or share warmup, etc.

May I suggest another name for the flag: use_current_state?

You mean change reuse_warmup --> use_current_state? The problem with that name is that it isn't really using the most current state from sampling. It is actually just using the last state returned by warmup. How about reuse_warmup_state?

@fehiepsi
Member

fehiepsi commented Nov 28, 2019

How about reuse_warmup_state?

Yeah, this is a better name for me. Please forget about use_current_state; it requires fori_collect to return the last state (instead of start_state). I don't intend to run with 1000 samples, then (do something and) keep running with 1000 samples, then (do something and) keep running to collect the next 1000 samples,... (though I believe people usually do it this way to collect samples from MCMC hubs - as I understand from the parallel HMC papers I have read). Let's revisit it when necessary.

Btw, I think you need to deal with multi-chains ;).

@neerajprad
Member Author

(though I believe people usually do it this way to collect samples from MCMC hubs - as I understand from the parallel HMC papers I have read). Let's revisit it when necessary.

Yes, I do realize that it might be important for certain use cases. And as you mention, we would rather have some users request it first before adding it in.

Btw, I think you need to deal with multi-chains ;).

I'll take a look. 😅

@neerajprad
Member Author

@fehiepsi - I didn't see anything immediately wrong with multiple chains, and I enabled this in the smoke test, which should pass. Let me know if I missed something.

@neerajprad neerajprad mentioned this pull request Nov 28, 2019
4 tasks
Member

@fehiepsi fehiepsi left a comment

Looks great! :D multi-chain should work with the current implementation (I thought that init_state = init_state[0] selects the first chain of multi-chain...).

@fehiepsi
Member

@neerajprad Can I merge this now?

@neerajprad
Member Author

Yes, let's merge this. If there are any follow-up issues, I'll put up a separate PR.

@fehiepsi fehiepsi merged commit 0d6839d into master Nov 28, 2019
@fehiepsi fehiepsi deleted the reuse-warmup branch November 28, 2019 04:04
@neerajprad neerajprad mentioned this pull request Nov 28, 2019
3 tasks