Skip to content
This repository has been archived by the owner on Dec 18, 2023. It is now read-only.

Robust U-turn condition #864

Closed
wants to merge 2 commits into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -135,39 +135,89 @@ def _build_tree(self, root: _TreeNode, tree_depth: int, args: _TreeArgs) -> _Tre
args=args,
)

# uniform progressive sampling (Appendix 3.1 of [2])
log_weight = torch.logaddexp(sub_tree.log_weight, other_sub_tree.log_weight)
log_tree_prob = other_sub_tree.log_weight - log_weight

# if log_tree_prob is NaN then this will evaluate to False; this can happen when
# the log weight of both trees are -inf
if torch.log1p(-torch.rand(())) <= log_tree_prob:
selected_subtree = other_sub_tree
return self._combine_tree(
sub_tree, other_sub_tree, args.direction, biased=False
)

def _combine_tree(
self, old_tree: _Tree, new_tree: _Tree, direction: int, biased: bool
) -> _Tree:
"""Combine the old tree and the new tree into a single (large) tree. The new
tree will be add to the left of the old tree if direction is -1, otherwise it
will be add to the right. If biased is True, then we will prefer choosing from
new tree (which is away from the starting location) than old tree when sampling
the next state from the trajectory. This function assumes old_tree is not
turned or diverged."""
# if old tree hsa turned or diverged, then we shouldn't build the new tree in
# the first place
assert not old_tree.turned_or_diverged
# log of the sum of the weights from both trees
log_weight = torch.logaddexp(old_tree.log_weight, new_tree.log_weight)

if new_tree.turned_or_diverged:
selected_subtree = old_tree
else:
# progressively sample from the trajectory
if biased:
# biased progressive sampling (Appendix 3.2 of [2])
log_tree_prob = new_tree.log_weight - old_tree.log_weight
else:
# uniform progressive sampling (Appendix 3.1 of [2])
log_tree_prob = new_tree.log_weight - log_weight

if torch.rand_like(log_tree_prob).log() < log_tree_prob:
selected_subtree = new_tree
else:
selected_subtree = old_tree

if direction == -1:
left_tree, right_tree = new_tree, old_tree
else:
selected_subtree = sub_tree
left_tree, right_tree = old_tree, new_tree

left_state = other_sub_tree.left if args.direction == -1 else sub_tree.left
right_state = sub_tree.right if args.direction == -1 else other_sub_tree.right
sum_momentums = {
node: sub_tree.sum_momentums[node] + other_sub_tree.sum_momentums[node]
for node in sub_tree.sum_momentums
node: left_tree.sum_momentums[node] + right_tree.sum_momentums[node]
for node in left_tree.sum_momentums
}
turned_or_diverged = new_tree.turned_or_diverged or self._is_u_turning(
left_tree.left.momentums,
right_tree.right.momentums,
sum_momentums,
)
# More robust U-turn condition
# https://discourse.mc-stan.org/t/nuts-misses-u-turns-runs-in-circles-until-max-treedepth/9727
if not turned_or_diverged and right_tree.num_proposals > 1:
extended_sum_momentums = {
node: left_tree.sum_momentums[node] + right_tree.left.momentums[node]
for node in sum_momentums
}
turned_or_diverged = self._is_u_turning(
left_tree.left.momentums,
right_tree.left.momentums,
extended_sum_momentums,
)
if not turned_or_diverged and left_tree.num_proposals > 1:
extended_sum_momentums = {
node: right_tree.sum_momentums[node] + left_tree.right.momentums[node]
for node in sum_momentums
}
turned_or_diverged = self._is_u_turning(
left_tree.right.momentums,
right_tree.right.momentums,
extended_sum_momentums,
)

return _Tree(
left=left_state,
right=right_state,
left=left_tree.left,
right=right_tree.right,
proposal=selected_subtree.proposal,
pe=selected_subtree.pe,
pe_grad=selected_subtree.pe_grad,
log_weight=log_weight,
sum_momentums=sum_momentums,
sum_accept_prob=sub_tree.sum_accept_prob + other_sub_tree.sum_accept_prob,
num_proposals=sub_tree.num_proposals + other_sub_tree.num_proposals,
turned_or_diverged=other_sub_tree.turned_or_diverged
or self._is_u_turning(
left_state.momentums,
right_state.momentums,
sum_momentums,
),
sum_accept_prob=old_tree.sum_accept_prob + new_tree.sum_accept_prob,
num_proposals=old_tree.num_proposals + new_tree.num_proposals,
turned_or_diverged=turned_or_diverged,
)

def propose(self, world: Optional[SimpleWorld] = None) -> SimpleWorld:
Expand All @@ -184,52 +234,32 @@ def propose(self, world: Optional[SimpleWorld] = None) -> SimpleWorld:
else:
# this is a more stable way to sample from log(Uniform(0, exp(-current_energy)))
log_slice = torch.log1p(-torch.rand(())) - current_energy
left_tree_node = right_tree_node = _TreeNode(
self.world, momentums, self._pe_grad
tree_node = _TreeNode(self.world, momentums, self._pe_grad)
tree = _Tree(
left=tree_node,
right=tree_node,
proposal=self.world,
pe=self._pe,
pe_grad=self._pe_grad,
log_weight=torch.tensor(0.0), # log accept prob of staying at current state
sum_momentums=momentums,
sum_accept_prob=torch.tensor(0.0),
num_proposals=0,
turned_or_diverged=False,
)
log_weight = torch.tensor(0.0) # log accept prob of staying at current state
sum_accept_prob = 0.0
num_proposals = 0
sum_momentums = momentums

for j in range(self._max_tree_depth):
direction = 1 if torch.rand(()) > 0.5 else -1
tree_args = _TreeArgs(log_slice, direction, self.step_size, current_energy)
if direction == -1:
tree = self._build_tree(left_tree_node, j, tree_args)
left_tree_node = tree.left
new_tree = self._build_tree(tree.left, j, tree_args)
else:
tree = self._build_tree(right_tree_node, j, tree_args)
right_tree_node = tree.right

sum_accept_prob += tree.sum_accept_prob
num_proposals += tree.num_proposals
new_tree = self._build_tree(tree.right, j, tree_args)

tree = self._combine_tree(tree, new_tree, direction, biased=True)
if tree.turned_or_diverged:
break

# biased progressive sampling (Appendix 3.2 of [2])
log_tree_prob = tree.log_weight - log_weight

# choose new world by randomly sample from proposed worlds
if torch.log1p(-torch.rand(())) <= log_tree_prob:
self.world, self._pe, self._pe_grad = (
tree.proposal,
tree.pe,
tree.pe_grad,
)
sum_momentums = {
node: sum_momentums[node] + tree.sum_momentums[node]
for node in sum_momentums
}
if self._is_u_turning(
left_tree_node.momentums,
right_tree_node.momentums,
sum_momentums,
):
break

log_weight = torch.logaddexp(log_weight, tree.log_weight)

self._alpha = sum_accept_prob / num_proposals
self.world, self._pe, self._pe_grad = tree.proposal, tree.pe, tree.pe_grad
self._alpha = tree.sum_accept_prob / tree.num_proposals
return self.world