Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Expose energy for bfmi diagnostic #369

Merged
merged 2 commits into from
Sep 30, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions numpyro/hmc_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

TreeInfo = namedtuple('TreeInfo', ['z_left', 'r_left', 'z_left_grad',
'z_right', 'r_right', 'z_right_grad',
'z_proposal', 'z_proposal_pe', 'z_proposal_grad',
'z_proposal', 'z_proposal_pe', 'z_proposal_grad', 'z_proposal_energy',
'depth', 'weight', 'r_sum', 'turning', 'diverging',
'sum_accept_probs', 'num_proposals'])

Expand Down Expand Up @@ -506,10 +506,10 @@ def _combine_tree(current_tree, new_tree, inverse_mass_matrix, going_right, rng,
turning = current_tree.turning

transition = random.bernoulli(rng, transition_prob)
z_proposal, z_proposal_pe, z_proposal_grad = cond(
z_proposal, z_proposal_pe, z_proposal_grad, z_proposal_energy = cond(
transition,
new_tree, lambda tree: (tree.z_proposal, tree.z_proposal_pe, tree.z_proposal_grad),
current_tree, lambda tree: (tree.z_proposal, tree.z_proposal_pe, tree.z_proposal_grad)
new_tree, lambda tree: (tree.z_proposal, tree.z_proposal_pe, tree.z_proposal_grad, tree.z_proposal_energy),
current_tree, lambda tree: (tree.z_proposal, tree.z_proposal_pe, tree.z_proposal_grad, tree.z_proposal_energy)
)

tree_depth = current_tree.depth + 1
Expand All @@ -520,7 +520,7 @@ def _combine_tree(current_tree, new_tree, inverse_mass_matrix, going_right, rng,
num_proposals = current_tree.num_proposals + new_tree.num_proposals

return TreeInfo(z_left, r_left, z_left_grad, z_right, r_right, r_right_grad,
z_proposal, z_proposal_pe, z_proposal_grad,
z_proposal, z_proposal_pe, z_proposal_grad, z_proposal_energy,
tree_depth, tree_weight, r_sum, turning, diverging,
sum_accept_probs, num_proposals)

Expand All @@ -543,7 +543,7 @@ def _build_basetree(vv_update, kinetic_fn, z, r, z_grad, inverse_mass_matrix, st
diverging = delta_energy > max_delta_energy
accept_prob = np.clip(np.exp(-delta_energy), a_max=1.0)
return TreeInfo(z_new, r_new, z_new_grad, z_new, r_new, z_new_grad,
z_new, potential_energy_new, z_new_grad,
z_new, potential_energy_new, z_new_grad, energy_new,
depth=0, weight=tree_weight, r_sum=r_new, turning=False,
diverging=diverging, sum_accept_probs=accept_prob, num_proposals=1)

Expand Down Expand Up @@ -654,7 +654,7 @@ def _body_fn(state):
# update depth and turning condition
return TreeInfo(tree.z_left, tree.r_left, tree.z_left_grad,
tree.z_right, tree.r_right, tree.z_right_grad,
tree.z_proposal, tree.z_proposal_pe, tree.z_proposal_grad,
tree.z_proposal, tree.z_proposal_pe, tree.z_proposal_grad, tree.z_proposal_energy,
depth, tree.weight, tree.r_sum, turning, tree.diverging,
tree.sum_accept_probs, tree.num_proposals)

Expand Down Expand Up @@ -689,7 +689,7 @@ def build_tree(verlet_update, kinetic_fn, verlet_state, inverse_mass_matrix, ste
r_ckpts = np.zeros((max_tree_depth, inverse_mass_matrix.shape[-1]))
r_sum_ckpts = np.zeros((max_tree_depth, inverse_mass_matrix.shape[-1]))

tree = TreeInfo(z, r, z_grad, z, r, z_grad, z, potential_energy, z_grad,
tree = TreeInfo(z, r, z_grad, z, r, z_grad, z, potential_energy, z_grad, energy_current,
depth=0, weight=0., r_sum=r, turning=False, diverging=False,
sum_accept_probs=0., num_proposals=0)

Expand Down
26 changes: 14 additions & 12 deletions numpyro/mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
)
from numpyro.util import cond, copy_docs_from, fori_collect, fori_loop, identity

HMCState = namedtuple('HMCState', ['i', 'z', 'z_grad', 'potential_energy', 'num_steps', 'accept_prob',
HMCState = namedtuple('HMCState', ['i', 'z', 'z_grad', 'potential_energy', 'energy', 'num_steps', 'accept_prob',
'mean_accept_prob', 'diverging', 'adapt_state', 'rng'])
"""
A :func:`~collections.namedtuple` consisting of the following fields:
Expand All @@ -37,6 +37,7 @@
the posterior) at latent sites.
- **z_grad** - Gradient of potential energy w.r.t. latent sample sites.
- **potential_energy** - Potential energy computed at the given value of ``z``.
- **energy** - Sum of potential energy and kinetic energy of the current state.
- **num_steps** - Number of steps in the Hamiltonian trajectory (for diagnostics).
- **accept_prob** - Acceptance probability of the proposal. Note that ``z``
does not correspond to the proposal if it is rejected.
Expand Down Expand Up @@ -236,8 +237,9 @@ def init_kernel(init_params,
wa_state = wa_init(z, rng_wa, step_size, mass_matrix_size=np.size(z_flat))
r = momentum_generator(wa_state.mass_matrix_sqrt, rng)
vv_state = vv_init(z, r)
hmc_state = HMCState(0, vv_state.z, vv_state.z_grad, vv_state.potential_energy, 0, 0., 0.,
False, wa_state, rng_hmc)
energy = kinetic_fn(wa_state.inverse_mass_matrix, vv_state.r)
hmc_state = HMCState(0, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy,
0, 0., 0., False, wa_state, rng_hmc)

# TODO: Remove; this should be the responsibility of the MCMC class.
if run_warmup and num_warmup > 0:
Expand All @@ -263,10 +265,10 @@ def _hmc_next(step_size, inverse_mass_matrix, vv_state, rng):
accept_prob = np.clip(np.exp(-delta_energy), a_max=1.0)
diverging = delta_energy > max_delta_energy
transition = random.bernoulli(rng, accept_prob)
vv_state = cond(transition,
vv_state_new, lambda state: state,
vv_state, lambda state: state)
return vv_state, num_steps, accept_prob, diverging
vv_state, energy = cond(transition,
(vv_state_new, energy_new), lambda args: args,
(vv_state, energy_old), lambda args: args)
return vv_state, energy, num_steps, accept_prob, diverging

def _nuts_next(step_size, inverse_mass_matrix, vv_state, rng):
binary_tree = build_tree(vv_update, kinetic_fn, vv_state,
Expand All @@ -279,7 +281,7 @@ def _nuts_next(step_size, inverse_mass_matrix, vv_state, rng):
r=vv_state.r,
potential_energy=binary_tree.z_proposal_pe,
z_grad=binary_tree.z_proposal_grad)
return vv_state, num_steps, accept_prob, binary_tree.diverging
return vv_state, binary_tree.z_proposal_energy, num_steps, accept_prob, binary_tree.diverging

_next = _nuts_next if algo == 'NUTS' else _hmc_next

Expand All @@ -295,9 +297,9 @@ def sample_kernel(hmc_state):
rng, rng_momentum, rng_transition = random.split(hmc_state.rng, 3)
r = momentum_generator(hmc_state.adapt_state.mass_matrix_sqrt, rng_momentum)
vv_state = IntegratorState(hmc_state.z, r, hmc_state.potential_energy, hmc_state.z_grad)
vv_state, num_steps, accept_prob, diverging = _next(hmc_state.adapt_state.step_size,
hmc_state.adapt_state.inverse_mass_matrix,
vv_state, rng_transition)
vv_state, energy, num_steps, accept_prob, diverging = _next(hmc_state.adapt_state.step_size,
hmc_state.adapt_state.inverse_mass_matrix,
vv_state, rng_transition)
# do not update adapt_state after the warmup phase
adapt_state = cond(hmc_state.i < wa_steps,
(hmc_state.i, accept_prob, vv_state.z, hmc_state.adapt_state),
Expand All @@ -309,7 +311,7 @@ def sample_kernel(hmc_state):
n = np.where(hmc_state.i < wa_steps, itr, itr - wa_steps)
mean_accept_prob = hmc_state.mean_accept_prob + (accept_prob - hmc_state.mean_accept_prob) / n

return HMCState(itr, vv_state.z, vv_state.z_grad, vv_state.potential_energy, num_steps,
return HMCState(itr, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy, num_steps,
accept_prob, mean_accept_prob, diverging, adapt_state, rng)

# Make `init_kernel` and `sample_kernel` visible from the global scope once
Expand Down
2 changes: 1 addition & 1 deletion test/test_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ def model(labels):
init_params, potential_fn, constrain_fn = initialize_model(random.PRNGKey(2), model, labels)
samples = mcmc(warmup_steps, num_samples, init_params, sampler='hmc', algo=algo,
potential_fn=potential_fn, trajectory_length=10, constrain_fn=constrain_fn)
assert_allclose(np.mean(samples['coefs'], 0), true_coefs, atol=0.21)
assert_allclose(np.mean(samples['coefs'], 0), true_coefs, atol=0.22)

if 'JAX_ENABLE_x64' in os.environ:
assert samples['coefs'].dtype == np.float64
Expand Down