diff --git a/numpyro/hmc_util.py b/numpyro/hmc_util.py index cb4a0a32f..6df731ffd 100644 --- a/numpyro/hmc_util.py +++ b/numpyro/hmc_util.py @@ -22,7 +22,7 @@ TreeInfo = namedtuple('TreeInfo', ['z_left', 'r_left', 'z_left_grad', 'z_right', 'r_right', 'z_right_grad', - 'z_proposal', 'z_proposal_pe', 'z_proposal_grad', + 'z_proposal', 'z_proposal_pe', 'z_proposal_grad', 'z_proposal_energy', 'depth', 'weight', 'r_sum', 'turning', 'diverging', 'sum_accept_probs', 'num_proposals']) @@ -506,10 +506,10 @@ def _combine_tree(current_tree, new_tree, inverse_mass_matrix, going_right, rng, turning = current_tree.turning transition = random.bernoulli(rng, transition_prob) - z_proposal, z_proposal_pe, z_proposal_grad = cond( + z_proposal, z_proposal_pe, z_proposal_grad, z_proposal_energy = cond( transition, - new_tree, lambda tree: (tree.z_proposal, tree.z_proposal_pe, tree.z_proposal_grad), - current_tree, lambda tree: (tree.z_proposal, tree.z_proposal_pe, tree.z_proposal_grad) + new_tree, lambda tree: (tree.z_proposal, tree.z_proposal_pe, tree.z_proposal_grad, tree.z_proposal_energy), + current_tree, lambda tree: (tree.z_proposal, tree.z_proposal_pe, tree.z_proposal_grad, tree.z_proposal_energy) ) tree_depth = current_tree.depth + 1 @@ -520,7 +520,7 @@ def _combine_tree(current_tree, new_tree, inverse_mass_matrix, going_right, rng, num_proposals = current_tree.num_proposals + new_tree.num_proposals return TreeInfo(z_left, r_left, z_left_grad, z_right, r_right, r_right_grad, - z_proposal, z_proposal_pe, z_proposal_grad, + z_proposal, z_proposal_pe, z_proposal_grad, z_proposal_energy, tree_depth, tree_weight, r_sum, turning, diverging, sum_accept_probs, num_proposals) @@ -543,7 +543,7 @@ def _build_basetree(vv_update, kinetic_fn, z, r, z_grad, inverse_mass_matrix, st diverging = delta_energy > max_delta_energy accept_prob = np.clip(np.exp(-delta_energy), a_max=1.0) return TreeInfo(z_new, r_new, z_new_grad, z_new, r_new, z_new_grad, - z_new, potential_energy_new, z_new_grad, + 
z_new, potential_energy_new, z_new_grad, energy_new, depth=0, weight=tree_weight, r_sum=r_new, turning=False, diverging=diverging, sum_accept_probs=accept_prob, num_proposals=1) @@ -654,7 +654,7 @@ def _body_fn(state): # update depth and turning condition return TreeInfo(tree.z_left, tree.r_left, tree.z_left_grad, tree.z_right, tree.r_right, tree.z_right_grad, - tree.z_proposal, tree.z_proposal_pe, tree.z_proposal_grad, + tree.z_proposal, tree.z_proposal_pe, tree.z_proposal_grad, tree.z_proposal_energy, depth, tree.weight, tree.r_sum, turning, tree.diverging, tree.sum_accept_probs, tree.num_proposals) @@ -689,7 +689,7 @@ def build_tree(verlet_update, kinetic_fn, verlet_state, inverse_mass_matrix, ste r_ckpts = np.zeros((max_tree_depth, inverse_mass_matrix.shape[-1])) r_sum_ckpts = np.zeros((max_tree_depth, inverse_mass_matrix.shape[-1])) - tree = TreeInfo(z, r, z_grad, z, r, z_grad, z, potential_energy, z_grad, + tree = TreeInfo(z, r, z_grad, z, r, z_grad, z, potential_energy, z_grad, energy_current, depth=0, weight=0., r_sum=r, turning=False, diverging=False, sum_accept_probs=0., num_proposals=0) diff --git a/numpyro/mcmc.py b/numpyro/mcmc.py index 78410fe60..088990bc8 100644 --- a/numpyro/mcmc.py +++ b/numpyro/mcmc.py @@ -27,7 +27,7 @@ ) from numpyro.util import cond, copy_docs_from, fori_collect, fori_loop, identity -HMCState = namedtuple('HMCState', ['i', 'z', 'z_grad', 'potential_energy', 'num_steps', 'accept_prob', +HMCState = namedtuple('HMCState', ['i', 'z', 'z_grad', 'potential_energy', 'energy', 'num_steps', 'accept_prob', 'mean_accept_prob', 'diverging', 'adapt_state', 'rng']) """ A :func:`~collections.namedtuple` consisting of the following fields: @@ -37,6 +37,7 @@ the posterior) at latent sites. - **z_grad** - Gradient of potential energy w.r.t. latent sample sites. - **potential_energy** - Potential energy computed at the given value of ``z``. + - **energy** - Sum of potential energy and kinetic energy of the current state. 
- **num_steps** - Number of steps in the Hamiltonian trajectory (for diagnostics). - **accept_prob** - Acceptance probability of the proposal. Note that ``z`` does not correspond to the proposal if it is rejected. @@ -236,8 +237,9 @@ def init_kernel(init_params, wa_state = wa_init(z, rng_wa, step_size, mass_matrix_size=np.size(z_flat)) r = momentum_generator(wa_state.mass_matrix_sqrt, rng) vv_state = vv_init(z, r) - hmc_state = HMCState(0, vv_state.z, vv_state.z_grad, vv_state.potential_energy, 0, 0., 0., - False, wa_state, rng_hmc) + energy = vv_state.potential_energy + kinetic_fn(wa_state.inverse_mass_matrix, vv_state.r) + hmc_state = HMCState(0, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy, + 0, 0., 0., False, wa_state, rng_hmc) # TODO: Remove; this should be the responsibility of the MCMC class. if run_warmup and num_warmup > 0: @@ -263,10 +265,10 @@ def _hmc_next(step_size, inverse_mass_matrix, vv_state, rng): accept_prob = np.clip(np.exp(-delta_energy), a_max=1.0) diverging = delta_energy > max_delta_energy transition = random.bernoulli(rng, accept_prob) - vv_state = cond(transition, - vv_state_new, lambda state: state, - vv_state, lambda state: state) - return vv_state, num_steps, accept_prob, diverging + vv_state, energy = cond(transition, + (vv_state_new, energy_new), lambda args: args, + (vv_state, energy_old), lambda args: args) + return vv_state, energy, num_steps, accept_prob, diverging def _nuts_next(step_size, inverse_mass_matrix, vv_state, rng): binary_tree = build_tree(vv_update, kinetic_fn, vv_state, @@ -279,7 +281,7 @@ def _nuts_next(step_size, inverse_mass_matrix, vv_state, rng): r=vv_state.r, potential_energy=binary_tree.z_proposal_pe, z_grad=binary_tree.z_proposal_grad) - return vv_state, num_steps, accept_prob, binary_tree.diverging + return vv_state, binary_tree.z_proposal_energy, num_steps, accept_prob, binary_tree.diverging _next = _nuts_next if algo == 'NUTS' else _hmc_next @@ -295,9 +297,9 @@ def sample_kernel(hmc_state): rng, rng_momentum,
rng_transition = random.split(hmc_state.rng, 3) r = momentum_generator(hmc_state.adapt_state.mass_matrix_sqrt, rng_momentum) vv_state = IntegratorState(hmc_state.z, r, hmc_state.potential_energy, hmc_state.z_grad) - vv_state, num_steps, accept_prob, diverging = _next(hmc_state.adapt_state.step_size, - hmc_state.adapt_state.inverse_mass_matrix, - vv_state, rng_transition) + vv_state, energy, num_steps, accept_prob, diverging = _next(hmc_state.adapt_state.step_size, + hmc_state.adapt_state.inverse_mass_matrix, + vv_state, rng_transition) # not update adapt_state after warmup phase adapt_state = cond(hmc_state.i < wa_steps, (hmc_state.i, accept_prob, vv_state.z, hmc_state.adapt_state), @@ -309,7 +311,7 @@ def sample_kernel(hmc_state): n = np.where(hmc_state.i < wa_steps, itr, itr - wa_steps) mean_accept_prob = hmc_state.mean_accept_prob + (accept_prob - hmc_state.mean_accept_prob) / n - return HMCState(itr, vv_state.z, vv_state.z_grad, vv_state.potential_energy, num_steps, + return HMCState(itr, vv_state.z, vv_state.z_grad, vv_state.potential_energy, energy, num_steps, accept_prob, mean_accept_prob, diverging, adapt_state, rng) # Make `init_kernel` and `sample_kernel` visible from the global scope once diff --git a/test/test_mcmc.py b/test/test_mcmc.py index 712d58f6b..5e7f6e39e 100644 --- a/test/test_mcmc.py +++ b/test/test_mcmc.py @@ -31,7 +31,7 @@ def model(labels): init_params, potential_fn, constrain_fn = initialize_model(random.PRNGKey(2), model, labels) samples = mcmc(warmup_steps, num_samples, init_params, sampler='hmc', algo=algo, potential_fn=potential_fn, trajectory_length=10, constrain_fn=constrain_fn) - assert_allclose(np.mean(samples['coefs'], 0), true_coefs, atol=0.21) + assert_allclose(np.mean(samples['coefs'], 0), true_coefs, atol=0.22) if 'JAX_ENABLE_x64' in os.environ: assert samples['coefs'].dtype == np.float64