Summary:
Pull Request resolved: facebookresearch#1369
Key Changes:
- This diff implements the multiprocessing logic and introduces a new argument, `run_in_parallel`, so users can choose to run multi-chain inference in parallel (in subprocesses); a usage sketch follows after this list.
- For the progress bars, it seems like as long as we pass `tqdm` a lock and use the `position` argument, `tqdm` can correctly update the progress bar for each subprocess, so we don't need to keep a subprocess dedicated to updating progress bars like Pyro does :). One downside of this approach compared to a dedicated process is that in Jupyter notebooks the order of the progress bars can get scrambled (e.g. the progress bar for the 5th chain can appear on the 1st row, see screenshot below), but that shouldn't matter for our use case. (A toy snippet along these lines is sketched after this list.)
{F706198308}
- (The screenshot was taken from a toy snippet used to test the progress bars, not from BM :))
- We also need to change how samples are gathered, because sending an `RVIdentifier` back and forth between processes can change its hash value. As a result, we can run into a `KeyError` when merging the dictionaries of samples sent back to the main process. The solution here is to return a list of `Tensor`s instead and use the order of the queries to determine which `Tensor` corresponds to which `RVIdentifier` (see the sketch after this list).
- Users can use the new `mp_context` argument to control how new subprocesses are started ([see the multiprocessing docs for details](https://docs.python.org/3.8/library/multiprocessing.html#contexts-and-start-methods)); a short sketch of the start methods follows after this list.
- **Note**: for gradient-based methods such as NMC and NUTS, the usual caveats of running autograd with fork-based multiprocessing still apply: https://github.com/pytorch/pytorch/wiki/Autograd-and-Fork. It seems like autograd initializes some internal state the first time it is executed, and fork-based multiprocessing copies that state into the subprocesses, which can be problematic. PyTorch therefore recommends "spawn" mode for multiprocessing, but spawn mode doesn't work in interactive environments such as Jupyter notebooks. One way to work around this in Jupyter notebooks is to keep using the default "fork" mode but never initialize the autograd state in the main process (i.e. always run inference in subprocesses). This is not an elegant solution, but at least it works. From a previous conversation with OpenTeams, it seems like Dask does not trigger PyTorch's autograd warning, so we should still look into whether it could be a better long-term solution.
- When `run_in_parallel` is `True`, we will pre-sample a seed for each chain and pass it to the corresponding subprocess. This ensures that the RNG for each chain is set to a different state (see the seeding sketch after this list).
- We could use the same mechanism to set the seed for non-parallel inference as well, but doing so would change the stochastic behavior of our existing tutorials and use cases, so I'd rather not do that right now since this diff already contains a lot of changes :)
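
For reference, here is a minimal sketch of what the new flag looks like at the call site. The toy model is illustrative and the `infer()` signature is assumed to match the existing one; only `run_in_parallel` (and `mp_context`, sketched further below) are new in this diff.

```python
import beanmachine.ppl as bm
import torch
import torch.distributions as dist


@bm.random_variable
def mu():
    return dist.Normal(0.0, 1.0)


@bm.random_variable
def x():
    return dist.Normal(mu(), 1.0)


if __name__ == "__main__":
    samples = bm.SingleSiteAncestralMetropolisHastings().infer(
        queries=[mu()],
        observations={x(): torch.tensor(0.5)},
        num_samples=1000,
        num_chains=4,
        run_in_parallel=True,  # new in this diff: run each chain in its own subprocess
    )
```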
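
A toy snippet along the lines of the one used for the screenshot above (not the actual BM code): each worker gets its own `position`, and the lock is shared with the pool workers via `initializer`, following tqdm's documented multiprocessing pattern.

```python
import time
from multiprocessing import Pool, RLock

from tqdm import tqdm


def run_chain(position):
    # `position` pins each bar to its own row; the shared lock serializes writes.
    for _ in tqdm(range(100), desc=f"chain {position}", position=position):
        time.sleep(0.01)  # stand-in for one sampling step


if __name__ == "__main__":
    tqdm.set_lock(RLock())
    with Pool(4, initializer=tqdm.set_lock, initargs=(tqdm.get_lock(),)) as pool:
        pool.map(run_chain, range(4))
```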
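
A rough sketch of the re-ordering trick for the samples, with hypothetical helper names (the diff may organize this differently):

```python
from typing import Dict, List

import torch


def chain_result_to_list(samples: Dict, queries: List) -> List[torch.Tensor]:
    # Inside the subprocess: flatten the {RVIdentifier: Tensor} dict into a list
    # ordered by `queries`, since RVIdentifier hashes may differ after pickling.
    return [samples[query] for query in queries]


def list_to_chain_result(tensors: List[torch.Tensor], queries: List) -> Dict:
    # In the main process: the i-th tensor always belongs to queries[i], so the
    # dict can be rebuilt with the main process's own RVIdentifier objects.
    return dict(zip(queries, tensors))
```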
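
For context on `mp_context`, the standard start methods are shown below; whether the argument expects a context object or a start-method string isn't spelled out in this summary, so the last comment is an assumption.

```python
import multiprocessing as mp

# Start methods available on this platform ("fork", "spawn", and/or "forkserver").
print(mp.get_all_start_methods())

# A context bundles a start method: "spawn" starts a fresh interpreter and avoids
# inheriting autograd state, while "fork" (the Linux default) copies the parent.
spawn_ctx = mp.get_context("spawn")

# Hypothetically passed to infer(...) as `mp_context=spawn_ctx`.
```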
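
A minimal sketch of the seeding scheme, with hypothetical helper names (the real code in the diff may seed additional RNGs):

```python
import random

import torch


def sample_seeds(num_chains: int):
    # Drawn once in the main process so every chain gets a distinct seed.
    return [random.randint(0, 2**32 - 1) for _ in range(num_chains)]


def seed_chain_rngs(seed: int) -> None:
    # Called at the start of each subprocess, before running its chain.
    random.seed(seed)
    torch.manual_seed(seed)
```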
Differential Revision: D34574082
fbshipit-source-id: 32237561392a0e7b9a4b7392a297fdc35642f331