Skip to content

Commit

Permalink
Adding option to perform more robust no-U-turn check
Browse files Browse the repository at this point in the history
Adds keyword argument to dynamic integration transitions and samplers to 
enable extra subtree termination criterion checks as described in 
stan-dev/stan#2800. Extra subtree checks are set 
to be enabled by default for the `DynamicMultinomialHMC` sampler.
  • Loading branch information
matt-graham committed May 7, 2020
1 parent b0b18ed commit be3c4cd
Show file tree
Hide file tree
Showing 2 changed files with 107 additions and 24 deletions.
52 changes: 48 additions & 4 deletions mici/samplers.py
Original file line number Diff line number Diff line change
Expand Up @@ -1229,6 +1229,11 @@ class DynamicMultinomialHMC(HamiltonianMCMC):
relative probability densities of the different candidate states, with the
resampling biased towards states further from the current state.
When used with the default settings of `riemannian_no_u_turn_criterion`
termination criterion and extra subtree checks enabled, this sampler is
equivalent to the default 'NUTS' MCMC algorithm (minus adaptation) used in
Stan as of version v2.23.
References:
1. Hoffman, M.D. and Gelman, A., 2014. The No-U-turn sampler:
Expand All @@ -1241,7 +1246,7 @@ class DynamicMultinomialHMC(HamiltonianMCMC):
def __init__(self, system, integrator, rng,
max_tree_depth=10, max_delta_h=1000,
termination_criterion=trans.riemannian_no_u_turn_criterion,
momentum_transition=None):
do_extra_subtree_checks=True, momentum_transition=None):
"""
Args:
system (mici.systems.System): Hamiltonian system to be simulated.
Expand All @@ -1263,6 +1268,23 @@ def __init__(self, system, integrator, rng,
(sub-)tree being checked and an array containing the sum of the
momentums over the trajectory (sub)-tree. Defaults to
`mici.transitions.riemannian_no_u_turn_criterion`.
do_extra_subtree_checks (bool): Whether to perform additional
termination criterion checks on overlapping subtrees of the
current tree to improve robustness in systems with dynamics
which are well approximated by independent system of simple
harmonic oscillators. In such systems (corresponding to e.g.
a standard normal target distribution and identity metric
matrix representation) at certain step sizes a 'resonant'
behaviour is seen by which the termination criterion fails to
detect that the trajectory has expanded past a half-period i.e.
has 'U-turned' resulting in trajectories continuing to expand,
potentially up until the `max_tree_depth` limit is hit. For
more details see the Stan Discourse discussion at kutt.it/yAkIES
If `do_extra_subtree_checks` is set to `True` additional
termination criterion checks are performed on overlapping
subtrees which help to reduce this resonant behaviour at the
cost of more conservative trajectory termination in some
correlated models and some overhead from additional checks.
momentum_transition (None or mici.transitions.MomentumTransition):
Markov transition kernel which leaves the conditional
distribution on the momentum under the canonical distribution
Expand All @@ -1274,7 +1296,7 @@ def __init__(self, system, integrator, rng,
"""
integration_transition = trans.MultinomialDynamicIntegrationTransition(
system, integrator, max_tree_depth, max_delta_h,
termination_criterion)
termination_criterion, do_extra_subtree_checks)
super().__init__(system, rng, integration_transition,
momentum_transition)

Expand Down Expand Up @@ -1309,6 +1331,11 @@ class DynamicSliceHMC(HamiltonianMCMC):
probability densities of the different candidate states, with the sampling
biased towards states further from the current state.
When used with the default setting of `euclidean_no_u_turn_criterion`
termination criterion and extra subtree checks disabled, this sampler is
equivalent to 'Algorithm 3: Efficient No-U-Turn Sampler' in [1], i.e. the
'classic NUTS' algorithm.
References:
1. Hoffman, M.D. and Gelman, A., 2014. The No-U-turn sampler:
Expand All @@ -1319,7 +1346,7 @@ class DynamicSliceHMC(HamiltonianMCMC):
def __init__(self, system, integrator, rng,
max_tree_depth=10, max_delta_h=1000,
termination_criterion=trans.euclidean_no_u_turn_criterion,
momentum_transition=None):
do_extra_subtree_checks=False, momentum_transition=None):
"""
Args:
system (mici.systems.System): Hamiltonian system to be simulated.
Expand All @@ -1341,6 +1368,23 @@ def __init__(self, system, integrator, rng,
(sub-)tree being checked and an array containing the sum of the
momentums over the trajectory (sub)-tree. Defaults to
`mici.transitions.euclidean_no_u_turn_criterion`.
do_extra_subtree_checks (bool): Whether to perform additional
termination criterion checks on overlapping subtrees of the
current tree to improve robustness in systems with dynamics
which are well approximated by independent system of simple
harmonic oscillators. In such systems (corresponding to e.g.
a standard normal target distribution and identity metric
matrix representation) at certain step sizes a 'resonant'
behaviour is seen by which the termination criterion fails to
detect that the trajectory has expanded past a half-period i.e.
has 'U-turned' resulting in trajectories continuing to expand,
potentially up until the `max_tree_depth` limit is hit. For
more details see the Stan Discourse discussion at kutt.it/yAkIES
If `do_extra_subtree_checks` is set to `True` additional
termination criterion checks are performed on overlapping
subtrees which help to reduce this resonant behaviour at the
cost of more conservative trajectory termination in some
correlated models and some overhead from additional checks.
momentum_transition (None or mici.transitions.MomentumTransition):
Markov transition kernel which leaves the conditional
distribution on the momentum under the canonical distribution
Expand All @@ -1352,7 +1396,7 @@ def __init__(self, system, integrator, rng,
"""
integration_transition = trans.SliceDynamicIntegrationTransition(
system, integrator, max_tree_depth, max_delta_h,
termination_criterion)
termination_criterion, do_extra_subtree_checks)
super().__init__(system, rng, integration_transition,
momentum_transition)

Expand Down
79 changes: 59 additions & 20 deletions mici/transitions.py
Original file line number Diff line number Diff line change
Expand Up @@ -391,7 +391,8 @@ def riemannian_no_u_turn_criterion(system, state_1, state_2, sum_mom):
np.sum(system.dh_dmom(state_2) * sum_mom) < 0)


SubTree = namedtuple('SubTree', ['negative', 'positive', 'sum_mom', 'weight'])
SubTree = namedtuple('SubTree', [
'negative', 'positive', 'sum_mom', 'weight', 'depth'])


class DynamicIntegrationTransition(IntegrationTransition):
Expand All @@ -415,7 +416,8 @@ class DynamicIntegrationTransition(IntegrationTransition):
"""
def __init__(self, system, integrator,
max_tree_depth=10, max_delta_h=1000,
termination_criterion=riemannian_no_u_turn_criterion):
termination_criterion=riemannian_no_u_turn_criterion,
do_extra_subtree_checks=True):
"""
Args:
system (mici.systems.System): Hamiltonian system to be simulated.
Expand All @@ -435,30 +437,65 @@ def __init__(self, system, integrator,
(sub-)tree being checked and an array containing the sum of the
momentums over the trajectory (sub)-tree. Defaults to
`riemannian_no_u_turn_criterion`.
do_extra_subtree_checks (bool): Whether to perform additional
termination criterion checks on overlapping subtrees of the
current tree to improve robustness in systems with dynamics
which are well approximated by independent system of simple
harmonic oscillators. In such systems (corresponding to e.g.
a standard normal target distribution and identity metric
matrix representation) at certain step sizes a 'resonant'
behaviour is seen by which the termination criterion fails to
detect that the trajectory has expanded past a half-period i.e.
has 'U-turned' resulting in trajectories continuing to expand,
potentially up until the `max_tree_depth` limit is hit. For
more details see the Stan Discourse discussion at kutt.it/yAkIES
If `do_extra_subtree_checks` is set to `True` additional
termination criterion checks are performed on overlapping
subtrees which help to reduce this resonant behaviour at the
cost of more conservative trajectory termination in some
correlated models and some overhead from additional checks.
"""
super().__init__(system, integrator)
assert max_tree_depth > 0, 'max_tree_depth must be non-negative'
self.max_tree_depth = max_tree_depth
self.max_delta_h = max_delta_h
self.termination_criterion = termination_criterion
self.do_extra_subtree_checks = do_extra_subtree_checks
self.statistic_types['tree_depth'] = (np.int64, -1)
self.statistic_types['diverging'] = (np.bool, False)

def _termination_criterion(self, tree):
return self.termination_criterion(
self.system, tree.negative, tree.positive, tree.sum_mom)
def _termination_criterion(self, tree, neg_subtree, pos_subtree):
# If performing extra subtree checks evaluate lazily i.e. only evaluate
# if initial whole tree check fails. Extra subtree checks also only
# performed for trees of depth 2 and above (i.e. containing at least
# 4 states) as for trees of depth 1 they are redundant.
if self.termination_criterion(
self.system, tree.negative, tree.positive, tree.sum_mom):
return True
elif tree.depth > 1 and self.do_extra_subtree_checks:
if self.termination_criterion(
self.system, neg_subtree.negative, pos_subtree.negative,
neg_subtree.sum_mom + pos_subtree.negative.mom):
return True
elif self.termination_criterion(
self.system, neg_subtree.positive, pos_subtree.positive,
pos_subtree.sum_mom + neg_subtree.positive.mom):
return True
return False

def _new_leave(self, state, h, aux_info):
return SubTree(
negative=state, positive=state,
sum_mom=np.asarray(state.mom),
weight=self._weight_function(h, aux_info))
negative=state, positive=state, sum_mom=np.asarray(state.mom),
weight=self._weight_function(h, aux_info), depth=0)

def _merge_subtrees(self, negative_tree, positive_tree):
def _merge_subtrees(self, neg_subtree, pos_subtree):
assert neg_subtree.depth == pos_subtree.depth, (
'Cannot merg subtrees of different depths')
return SubTree(
negative=negative_tree.negative, positive=positive_tree.positive,
weight=negative_tree.weight + positive_tree.weight,
sum_mom=negative_tree.sum_mom + positive_tree.sum_mom)
negative=neg_subtree.negative, positive=pos_subtree.positive,
weight=neg_subtree.weight + pos_subtree.weight,
sum_mom=neg_subtree.sum_mom + pos_subtree.sum_mom,
depth=neg_subtree.depth + 1)

def _init_aux_vars(self, state, rng):
return {'h_init': self.system.h(state)}
Expand Down Expand Up @@ -511,15 +548,16 @@ def _build_tree(self, depth, state, stats, rng, aux_vars):
if terminate:
return terminate, None, None
# merge two subtrees accounting for integration direction
tree = (self._merge_subtrees(inner_tree, outer_tree) if state.dir == 1
else self._merge_subtrees(outer_tree, inner_tree))
neg_subtree = inner_tree if state.dir == 1 else outer_tree
pos_subtree = outer_tree if state.dir == 1 else inner_tree
tree = self._merge_subtrees(neg_subtree, pos_subtree)
# sample new proposal from two subtree proposals according to weights
accept_outer_prob = self._weight_ratio(outer_tree.weight, tree.weight)
proposal = (
outer_proposal if rng.uniform() < accept_outer_prob else
inner_proposal)
# check termination criterion on new tree
terminate = self._termination_criterion(tree)
# check termination criterion on tree and subtrees
terminate = self._termination_criterion(tree, neg_subtree, pos_subtree)
return terminate, tree, proposal

def sample(self, state, rng):
Expand All @@ -543,10 +581,11 @@ def sample(self, state, rng):
if rng.uniform() < self._weight_ratio(new_tree.weight, tree.weight):
next_state = new_proposal
# merge new subtree into current tree accounting for direction
tree = (self._merge_subtrees(tree, new_tree) if direction == 1 else
self._merge_subtrees(new_tree, tree))
# check termination criterion on new tree
if self._termination_criterion(tree):
neg_subtree = tree if direction == 1 else new_tree
pos_subtree = new_tree if direction == 1 else tree
tree = self._merge_subtrees(neg_subtree, pos_subtree)
# check termination criterion on new tree and subtrees
if self._termination_criterion(tree, neg_subtree, pos_subtree):
break
sum_acc_prob = stats.pop('sum_acc_prob')
if stats['n_step'] > 0:
Expand Down

0 comments on commit be3c4cd

Please sign in to comment.