Iterative NUTS
Du Phan, Neeraj Pradhan
In this note, we present how we convert the recursive nature of NUTS into an iterative one, which integrates nicely with JAX's compilation mechanism. The two main references we use for the NUTS sampler are:
1. The No-U-Turn Sampler: Adaptively Setting Path Lengths in Hamiltonian Monte Carlo, arxiv. Matthew D. Hoffman, Andrew Gelman
2. A Conceptual Introduction to Hamiltonian Monte Carlo, arxiv. Michael Betancourt
While for HMC we know the number of leapfrog (Verlet) steps at the beginning of each trajectory, with NUTS we don't know in advance how far (that is, how many leapfrog steps) we will go. We keep moving until a turning condition happens. (There are two more stopping conditions: we are going too far, or the energy is diverging. We skip the discussion of those conditions because they are easy to check.) While moving, we construct a binary tree which keeps track of all the information we need: the left leaf, the right leaf, and a proposal for the next sampling step. The following figure from reference [1] nicely illustrates this process.
The algorithm works as follows:
- We start with a node called a basetree.
- Keep doubling the tree until turning happens:
  - Choose a direction: forward (or right) or backward (or left).
  - Build a subtree (by following that direction) with the same `depth = d` as the current tree by recursion, where the base case of this recursion process is to build a node (or a subtree with `depth = 0`):
    - Recursively build the first half (with `depth = d - 1`) of the subtree.
    - If the first half is turning, we stop. Otherwise, recursively build the second half with `depth = d - 1`.
    - This subtree is turning if: the first half is turning, or the second half is turning, or the combined one is turning.
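The recursive subtree build described above can be sketched on node indices alone. This is a minimal illustration, not the sampler itself: there are no leapfrog steps or proposals, and the names `build_subtree`, `is_turning`, and `checks` are hypothetical. The predicate `is_turning(left, right)` stands in for the real turning condition between the two end leaves of a tree.

```python
def build_subtree(start, depth, is_turning, checks):
    """Recursively build nodes start .. start + 2**depth - 1.

    Returns True as soon as any (sub)tree is turning. `checks` records
    every (leftmost, rightmost) pair passed to `is_turning`, in order.
    """
    if depth == 0:
        return False  # base case: a single node, nothing to compare
    half = 2 ** (depth - 1)
    # First half; stop early if it is already turning.
    if build_subtree(start, depth - 1, is_turning, checks):
        return True
    # Second half, built only when the first half is not turning.
    if build_subtree(start + half, depth - 1, is_turning, checks):
        return True
    # Finally check the combined tree using its two end leaves.
    left, right = start, start + 2 ** depth - 1
    checks.append((left, right))
    return is_turning(left, right)


checks = []
turned = build_subtree(0, 2, lambda left, right: False, checks)
print(turned, checks)  # False [(0, 1), (2, 3), (0, 3)]
```

The recursion depth is at most `depth`, which is where the small memory footprint of the recursive formulation comes from.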
By using recursion together with a neat transition kernel from the authors of NUTS (which governs the probability of accepting a proposal from the build process; see section A.3 of reference [2]), the memory requirement is `O(d)`, where `d` is the depth of the subtree. This is very efficient given that, to build that subtree, we have moved by `2^d` steps.
In this section, we illustrate how to build a subtree with depth 4 iteratively (instead of using recursion as above). This means we want to construct 16 nodes (numbered `0 -> 15`) for the tree. Instead of recursively building subtrees `0 -> 7` and `8 -> 15`, we iteratively go straight from `0` to `15` but stop when the turning condition happens. The trickiest parts are:
- Decide on a stopping condition which is equivalent to the recursive algorithm,
- Maintain the memory efficiency of the recursion.
For example, at node 3, we need to check the turning conditions for the following trees: `0 -> 3`, `2 -> 3`. At node 7, we check the turning conditions for the following trees: `0 -> 7`, `4 -> 7`, `6 -> 7`. At node 12, we don't need to check the turning condition. But at node 13, we need to check the turning condition of the tree `12 -> 13`. It is helpful to draw the binary tree and trace through the process.
The first case requires storing information at nodes 0 and 2. The second case requires storing information at nodes 0, 4, and 6. The fourth case requires storing information at node 12. The number of nodes which we need to store changes dynamically as we move, but the maximum is 4, which is the depth of the subtree we need to build. This maximum is attained at node 15, where we need to check the turning conditions for trees `0 -> 15`, `8 -> 15`, `12 -> 15`, `14 -> 15`.
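The examples above can be reproduced with a little arithmetic. The sketch below relies on one observation: a complete subtree of `size` nodes ends at a given node exactly when `node_idx + 1` is a multiple of `size`. The helper name `subtrees_ending_at` is hypothetical and only serves this illustration.

```python
def subtrees_ending_at(node_idx):
    """Return the (left, right) ranges of all subtrees whose last node is node_idx.

    Assumes node_idx is a non-negative integer.
    """
    ranges = []
    size = 2
    # A subtree with `size` nodes ends here iff node_idx + 1 is a multiple of size.
    while (node_idx + 1) % size == 0:
        ranges.append((node_idx + 1 - size, node_idx))
        size *= 2
    return ranges


for n in (3, 7, 12, 13, 15):
    print(n, subtrees_ending_at(n))
# 3 [(2, 3), (0, 3)]
# 7 [(6, 7), (4, 7), (0, 7)]
# 12 []
# 13 [(12, 13)]
# 15 [(14, 15), (12, 15), (8, 15), (0, 15)]

# The largest number of left nodes needed at once for a depth-4 tree.
print(max(len(subtrees_ending_at(n)) for n in range(16)))  # 4
```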
First, we create a storage `R[4]` to store information. Then, the whole process works as follows:
- Step 0: `R[0] = node_0`
- Step 1: check turning condition of `R[0]` (or `node_0`) and `node_1`
- Step 2: `R[1] = node_2`
- Step 3: check turning condition of `R[0]` and `node_3`, `R[1]` and `node_3`
- Step 4: `R[1] = node_4` (we update at index 1 because `node_2` is no longer needed for further processing)
- Step 5: check turning condition of `R[1]` and `node_5` (though it would be reasonable, we won't check the turning condition of `R[0]` and `node_5`; the reason is to make it equivalent to the recursive algorithm)
- Step 6: `R[2] = node_6`
- Step 7: check turning condition of `R[0]` and `node_7`, `R[1]` and `node_7`, `R[2]` and `node_7`
- Step 8: `R[1] = node_8`
- Step 9: check turning condition of `R[1]` and `node_9`
- Step 10: `R[2] = node_10`
- Step 11: check turning condition of `R[1]` and `node_11`, `R[2]` and `node_11`
- Step 12: `R[2] = node_12`
- Step 13: check turning condition of `R[2]` and `node_13`
- Step 14: `R[3] = node_14`
- Step 15: check turning condition of `R[0]` and `node_15`, `R[1]` and `node_15`, `R[2]` and `node_15`, `R[3]` and `node_15`
In summary, at even steps we update the storage, and at odd steps we check the turning conditions. If any turning condition is met, we stop. Otherwise, we go to the next step.
Now, we'll discuss two technical points of the iterative scheme:
- At even steps, which index of the memory do we need to update?
- At odd steps, which portion of the memory do we need to check for turning conditions against the current node?
Though it seems a bit tricky, things become easier to see when we look at the binary representation of node indices. The binary representation of node 7 is 111, which corresponds to the right leaves of trees with depths 1, 2, 3. The binary representation of node 12 is 1100, which corresponds to the left leaf of a tree with depth 1, the left leaf of a tree with depth 2, the right leaf of a tree with depth 3, and the right leaf of a tree with depth 4. In summary, a 1 at position `i` (counted from right to left in the binary representation) corresponds to the right leaf of a tree with depth `i`; a 0 at position `i` corresponds to the left leaf.
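As a quick check of this reading of the bits, the following sketch lists, for each depth `i`, whether a node sits at the left or right end according to bit `i` of its index. The function name `leaf_sides` is a hypothetical helper for this illustration.

```python
def leaf_sides(node_idx, depth):
    """For i = 1..depth, return 'right' if bit i (from the right) is 1, else 'left'."""
    return [
        "right" if (node_idx >> (i - 1)) & 1 else "left"
        for i in range(1, depth + 1)
    ]


print(leaf_sides(7, 3))   # ['right', 'right', 'right']
print(leaf_sides(12, 4))  # ['left', 'left', 'right', 'right']
```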
Even step: The mapping from node index (in binary representation) to memory index is: `0=0 -> 0`, `2=10 -> 1`, `4=100 -> 1`, `6=110 -> 2`, `8=1000 -> 1`, `10=1010 -> 2`, `12=1100 -> 2`, `14=1110 -> 3`. We can see that the memory index we need to update is the number of 1s in the binary form of the node index. So we can use a bitcount algorithm to calculate the memory index: `R_idx = bitcount(node_idx)`.
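The even-step rule is easy to verify against the mapping listed above. Here `bitcount` is written with plain Python string counting for clarity; this is just one possible way to count set bits, not necessarily what a compiled implementation would use.

```python
def bitcount(n):
    """Number of 1 bits in the binary form of a non-negative integer."""
    return bin(n).count("1")


# Memory index to update at each even node of a depth-4 subtree.
even_mapping = {node_idx: bitcount(node_idx) for node_idx in range(0, 16, 2)}
print(even_mapping)
# {0: 0, 2: 1, 4: 1, 6: 2, 8: 1, 10: 2, 12: 2, 14: 3}
```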
Odd step: We need to decide `idx_min` and `idx_max` of the memory so that we check the turning conditions for all indices from `idx_min` to `idx_max` w.r.t. the current node.
- `idx_max` is the memory index which we updated in the previous even step. So assume we are at node `9=1001`; then `idx_max = bitcount(1001 - 1) = bitcount(1000) = 1`. In other words, we count the number of 1s except the last one in the binary form.
- Instead of calculating `idx_min` directly, we calculate the number of indices for which we need to check turning conditions, that is, `num_idxs = idx_max - idx_min + 1`. The mapping from node index (in binary representation) to `num_idxs` is: `1=1 -> 1`, `3=11 -> 2`, `5=101 -> 1`, `7=111 -> 3`, `9=1001 -> 1`, `11=1011 -> 2`, `13=1101 -> 1`, `15=1111 -> 4`. We can see that `num_idxs` is the number of contiguous trailing 1 bits of the binary form (which corresponds to the contiguous right leaves of subtrees in the whole binary tree).
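The odd-step rule can be checked the same way. In this sketch, `trailing_ones` and `check_range` are hypothetical helper names; `check_range` returns the inclusive pair `(idx_min, idx_max)` of memory indices to compare against the current odd node.

```python
def bitcount(n):
    """Number of 1 bits in the binary form of a non-negative integer."""
    return bin(n).count("1")


def trailing_ones(n):
    """Number of contiguous 1 bits at the low end of n."""
    count = 0
    while n & 1:
        count += 1
        n >>= 1
    return count


def check_range(node_idx):
    """Inclusive memory index range (idx_min, idx_max) for an odd node."""
    idx_max = bitcount(node_idx - 1)   # index written at the previous even step
    num_idxs = trailing_ones(node_idx)
    return idx_max - num_idxs + 1, idx_max


for node_idx in range(1, 16, 2):
    print(node_idx, check_range(node_idx))
# 1 (0, 0)
# 3 (0, 1)
# 5 (1, 1)
# 7 (0, 2)
# 9 (1, 1)
# 11 (1, 2)
# 13 (2, 2)
# 15 (0, 3)
```

These ranges agree with the step-by-step walkthrough: for example, at node 11 only `R[1]` and `R[2]` are checked.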
By resolving the above two technical points, we have succeeded in converting the recursive nature of NUTS to an iterative one. The iterative scheme allows more control over the process (e.g. it might allow adding more stopping conditions to the scheme), but that's for future research. The biggest advantage right now is that it allows us to use JAX to compile the computation of the whole trajectory, which greatly reduces the overhead cost. We also want to note that recursion has been one of the technical issues in having a NUTS implementation in graph (vs. eager) mode, as mentioned in the Simple, Distributed, and Accelerated Probabilistic Programming paper. We also hope that our solution will help researchers in HMC investigate more cool features of NUTS.
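Putting the even-step and odd-step rules together, the whole iterative scheme can be sketched as a single loop over node indices. This is only an index-level illustration in plain Python under stated assumptions: nodes are represented by their indices, `is_turning(left, right)` is a stand-in for the real turning check, `depth` is at least 1, and the name `iterative_checks` is hypothetical. It is not the compiled JAX implementation.

```python
def bitcount(n):
    return bin(n).count("1")


def trailing_ones(n):
    count = 0
    while n & 1:
        count += 1
        n >>= 1
    return count


def iterative_checks(depth, is_turning):
    """Walk nodes 0 .. 2**depth - 1, returning (turned, checks)."""
    storage = [None] * depth  # plays the role of R in the text
    checks = []
    for node_idx in range(2 ** depth):
        if node_idx % 2 == 0:
            # Even step: store the node at index bitcount(node_idx).
            storage[bitcount(node_idx)] = node_idx
        else:
            # Odd step: compare against num_idxs stored nodes, from
            # idx_max down to idx_min (smallest subtree first).
            idx_max = bitcount(node_idx - 1)
            num_idxs = trailing_ones(node_idx)
            for idx in range(idx_max, idx_max - num_idxs, -1):
                left = storage[idx]
                checks.append((left, node_idx))
                if is_turning(left, node_idx):
                    return True, checks
    return False, checks


turned, checks = iterative_checks(2, lambda left, right: False)
print(turned, checks)  # False [(0, 1), (2, 3), (0, 3)]
```

Because the loop has a fixed structure and a fixed-size storage, it is the kind of control flow that can be expressed with a compiled loop primitive rather than Python recursion.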