diff --git a/src/stan/mcmc/hmc/nuts/base_nuts.hpp b/src/stan/mcmc/hmc/nuts/base_nuts.hpp index 04444cd8518..5c27cd3025f 100644 --- a/src/stan/mcmc/hmc/nuts/base_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/base_nuts.hpp @@ -81,56 +81,86 @@ namespace stan { this->hamiltonian_.sample_p(this->z_, this->rand_int_); this->hamiltonian_.init(this->z_, logger); - ps_point z_plus(this->z_); - ps_point z_minus(z_plus); + ps_point z_fwd(this->z_); // State at forward end of trajectory + ps_point z_bck(z_fwd); // State at backward end of trajectory - ps_point z_sample(z_plus); - ps_point z_propose(z_plus); + ps_point z_sample(z_fwd); + ps_point z_propose(z_fwd); - Eigen::VectorXd p_sharp_plus = this->hamiltonian_.dtau_dp(this->z_); - Eigen::VectorXd p_sharp_dummy = p_sharp_plus; - Eigen::VectorXd p_sharp_minus = p_sharp_plus; - Eigen::VectorXd rho = this->z_.p; + // Momentum and sharp momentum at forward end of forward subtree + Eigen::VectorXd p_fwd_fwd = this->z_.p; + Eigen::VectorXd p_sharp_fwd_fwd = this->hamiltonian_.dtau_dp(this->z_); + // Momentum and sharp momentum at backward end of forward subtree + Eigen::VectorXd p_fwd_bck = this->z_.p; + Eigen::VectorXd p_sharp_fwd_bck = p_sharp_fwd_fwd; + + // Momentum and sharp momentum at forward end of backward subtree + Eigen::VectorXd p_bck_fwd = this->z_.p; + Eigen::VectorXd p_sharp_bck_fwd = p_sharp_fwd_fwd; + + // Momentum and sharp momentum at backward end of backward subtree + Eigen::VectorXd p_bck_bck = this->z_.p; + Eigen::VectorXd p_sharp_bck_bck = p_sharp_fwd_fwd; + + // Integrated momenta along trajectory + Eigen::VectorXd rho = this->z_.p.transpose(); + + // Log sum of state weights (offset by H0) along trajectory double log_sum_weight = 0; // log(exp(H0 - H0)) double H0 = this->hamiltonian_.H(this->z_); int n_leapfrog = 0; double sum_metro_prob = 0; - // Build a trajectory until the NUTS criterion is no longer satisfied + // Build a trajectory until the no-u-turn + // criterion is no longer satisfied this->depth_ = 0; this->divergent_ = false; while (this->depth_ < this->max_depth_) { // Build a new subtree in a random direction - Eigen::VectorXd rho_subtree = Eigen::VectorXd::Zero(rho.size()); + Eigen::VectorXd rho_fwd = Eigen::VectorXd::Zero(rho.size()); + Eigen::VectorXd rho_bck = Eigen::VectorXd::Zero(rho.size()); + bool valid_subtree = false; double log_sum_weight_subtree = -std::numeric_limits::infinity(); if (this->rand_uniform_() > 0.5) { - this->z_.ps_point::operator=(z_plus); + // Extend the current trajectory forward + this->z_.ps_point::operator=(z_fwd); + rho_bck = rho; + p_bck_fwd = p_fwd_fwd; + p_sharp_bck_fwd = p_sharp_fwd_fwd; + valid_subtree = build_tree(this->depth_, z_propose, - p_sharp_dummy, p_sharp_plus, rho_subtree, + p_sharp_fwd_bck, p_sharp_fwd_fwd, + rho_fwd, p_fwd_bck, p_fwd_fwd, H0, 1, n_leapfrog, log_sum_weight_subtree, sum_metro_prob, logger); - z_plus.ps_point::operator=(this->z_); + z_fwd.ps_point::operator=(this->z_); } else { - this->z_.ps_point::operator=(z_minus); + // Extend the current trajectory backwards + this->z_.ps_point::operator=(z_bck); + rho_fwd = rho; + p_fwd_bck = p_bck_bck; + p_sharp_fwd_bck = p_sharp_bck_bck; + valid_subtree = build_tree(this->depth_, z_propose, - p_sharp_dummy, p_sharp_minus, rho_subtree, + p_sharp_bck_fwd, p_sharp_bck_bck, + rho_bck, p_bck_fwd, p_bck_bck, H0, -1, n_leapfrog, log_sum_weight_subtree, sum_metro_prob, logger); - z_minus.ps_point::operator=(this->z_); + z_bck.ps_point::operator=(this->z_); } if (!valid_subtree) break; - // Sample from an accepted subtree + // Sample from accepted subtree ++(this->depth_); if (log_sum_weight_subtree > log_sum_weight) { @@ -145,9 +175,30 @@ namespace stan { log_sum_weight = math::log_sum_exp(log_sum_weight, log_sum_weight_subtree); - // Break when NUTS criterion is no longer satisfied - rho += rho_subtree; - if (!compute_criterion(p_sharp_minus, p_sharp_plus, rho)) + // Break when no-u-turn criterion is no longer satisfied + rho = rho_bck + rho_fwd; + + // Demand satisfaction around merged subtrees + bool persist_criterion = + compute_criterion(p_sharp_bck_bck, + p_sharp_fwd_fwd, + rho); + + // Demand satisfaction between subtrees + Eigen::VectorXd rho_extended = rho_bck + p_fwd_bck; + + persist_criterion &= + compute_criterion(p_sharp_bck_bck, + p_sharp_fwd_bck, + rho_extended); + + rho_extended = rho_fwd + p_bck_fwd; + persist_criterion &= + compute_criterion(p_sharp_bck_fwd, + p_sharp_fwd_fwd, + rho_extended); + + if (!persist_criterion) break; } @@ -193,9 +244,11 @@ namespace stan { * * @param depth Depth of the desired subtree * @param z_propose State proposed from subtree - * @param p_sharp_left p_sharp from left boundary of returned tree - * @param p_sharp_right p_sharp from the right boundary of returned tree + * @param p_sharp_beg Sharp momentum at beginning of new tree + * @param p_sharp_end Sharp momentum at end of new tree * @param rho Summed momentum across trajectory + * @param p_beg Momentum at beginning of returned tree + * @param p_end Momentum at end of returned tree * @param H0 Hamiltonian of initial state * @param sign Direction in time to built subtree * @param n_leapfrog Summed number of leapfrog evaluations @@ -204,9 +257,11 @@ namespace stan { * @param logger Logger for messages */ bool build_tree(int depth, ps_point& z_propose, - Eigen::VectorXd& p_sharp_left, - Eigen::VectorXd& p_sharp_right, + Eigen::VectorXd& p_sharp_beg, + Eigen::VectorXd& p_sharp_end, Eigen::VectorXd& rho, + Eigen::VectorXd& p_beg, + Eigen::VectorXd& p_end, double H0, double sign, int& n_leapfrog, double& log_sum_weight, double& sum_metro_prob, callbacks::logger& logger) { @@ -231,63 +286,90 @@ namespace stan { sum_metro_prob += std::exp(H0 - h); z_propose = this->z_; - rho += this->z_.p; - p_sharp_left = this->hamiltonian_.dtau_dp(this->z_); - p_sharp_right = p_sharp_left; + p_sharp_beg = this->hamiltonian_.dtau_dp(this->z_); + p_sharp_end = p_sharp_beg; + + rho += this->z_.p; + p_beg = this->z_.p; + p_end = p_beg; return !this->divergent_; } // General recursion - Eigen::VectorXd p_sharp_dummy(this->z_.p.size()); - // Build the left subtree - double log_sum_weight_left = -std::numeric_limits::infinity(); - Eigen::VectorXd rho_left = Eigen::VectorXd::Zero(rho.size()); + // Build the initial subtree + double log_sum_weight_init = -std::numeric_limits::infinity(); + + // Momentum and sharp momentum at end of the initial subtree + Eigen::VectorXd p_init_end(this->z_.p.size()); + Eigen::VectorXd p_sharp_init_end(this->z_.p.size()); + + Eigen::VectorXd rho_init = Eigen::VectorXd::Zero(rho.size()); - bool valid_left + bool valid_init = build_tree(depth - 1, z_propose, - p_sharp_left, p_sharp_dummy, rho_left, + p_sharp_beg, p_sharp_init_end, + rho_init, p_beg, p_init_end, H0, sign, n_leapfrog, - log_sum_weight_left, sum_metro_prob, + log_sum_weight_init, sum_metro_prob, logger); - if (!valid_left) return false; + if (!valid_init) return false; - // Build the right subtree - ps_point z_propose_right(this->z_); + // Build the final subtree + ps_point z_propose_final(this->z_); - double log_sum_weight_right = -std::numeric_limits::infinity(); - Eigen::VectorXd rho_right = Eigen::VectorXd::Zero(rho.size()); + double log_sum_weight_final = -std::numeric_limits::infinity(); - bool valid_right - = build_tree(depth - 1, z_propose_right, - p_sharp_dummy, p_sharp_right, rho_right, + // Momentum and sharp momentum at beginning of the final subtree + Eigen::VectorXd p_final_beg(this->z_.p.size()); + Eigen::VectorXd p_sharp_final_beg(this->z_.p.size()); + + Eigen::VectorXd rho_final = Eigen::VectorXd::Zero(rho.size()); + + bool valid_final + = build_tree(depth - 1, z_propose_final, + p_sharp_final_beg, p_sharp_end, + rho_final, p_final_beg, p_end, H0, sign, n_leapfrog, - log_sum_weight_right, sum_metro_prob, + log_sum_weight_final, sum_metro_prob, logger); - if (!valid_right) return false; + if (!valid_final) return false; // Multinomial sample from right subtree double log_sum_weight_subtree - = math::log_sum_exp(log_sum_weight_left, log_sum_weight_right); + = math::log_sum_exp(log_sum_weight_init, log_sum_weight_final); log_sum_weight = math::log_sum_exp(log_sum_weight, log_sum_weight_subtree); - if (log_sum_weight_right > log_sum_weight_subtree) { - z_propose = z_propose_right; + if (log_sum_weight_final > log_sum_weight_subtree) { + z_propose = z_propose_final; } else { double accept_prob - = std::exp(log_sum_weight_right - log_sum_weight_subtree); + = std::exp(log_sum_weight_final - log_sum_weight_subtree); if (this->rand_uniform_() < accept_prob) - z_propose = z_propose_right; + z_propose = z_propose_final; } - Eigen::VectorXd rho_subtree = rho_left + rho_right; + Eigen::VectorXd rho_subtree = rho_init + rho_final; rho += rho_subtree; - return compute_criterion(p_sharp_left, p_sharp_right, rho_subtree); + // Demand satisfaction around merged subtrees + bool persist_criterion = + compute_criterion(p_sharp_beg, p_sharp_end, rho_subtree); + + // Demand satisfaction between subtrees + rho_subtree = rho_init + p_final_beg; + persist_criterion &= + compute_criterion(p_sharp_beg, p_sharp_final_beg, rho_subtree); + + rho_subtree = rho_final + p_init_end; + persist_criterion &= + compute_criterion(p_sharp_init_end, p_sharp_end, rho_subtree); + + return persist_criterion; } int depth_; diff --git a/src/test/performance/logistic_test.cpp b/src/test/performance/logistic_test.cpp index 0efe715a99c..de1b7a7607b 100644 --- a/src/test/performance/logistic_test.cpp +++ b/src/test/performance/logistic_test.cpp @@ -108,31 +108,31 @@ TEST_F(performance, values_from_tagged_version) { << "last tagged version, 2.17.0, had " << N_values << " elements"; std::vector first_run = last_draws_per_run[0]; - EXPECT_FLOAT_EQ(-65.781998, first_run[0]) + EXPECT_FLOAT_EQ(-65.216301, first_run[0]) << "lp__: index 0"; - EXPECT_FLOAT_EQ(1.0, first_run[1]) + EXPECT_FLOAT_EQ(0.91851199, first_run[1]) << "accept_stat__: index 1"; - EXPECT_FLOAT_EQ(0.76853198, first_run[2]) + EXPECT_FLOAT_EQ(0.76885802, first_run[2]) << "stepsize__: index 2"; EXPECT_FLOAT_EQ(2, first_run[3]) << "treedepth__: index 3"; - EXPECT_FLOAT_EQ(7, first_run[4]) + EXPECT_FLOAT_EQ(3, first_run[4]) << "n_leapfrog__: index 4"; EXPECT_FLOAT_EQ(0, first_run[5]) << "divergent__: index 5"; - EXPECT_FLOAT_EQ(66.6695, first_run[6]) + EXPECT_FLOAT_EQ(66.696503, first_run[6]) << "energy__: index 6"; - EXPECT_FLOAT_EQ(1.55186, first_run[7]) + EXPECT_FLOAT_EQ(1.3577, first_run[7]) << "beta.1: index 7"; - EXPECT_FLOAT_EQ(-0.52400702, first_run[8]) + EXPECT_FLOAT_EQ(-0.51189202, first_run[8]) << "beta.2: index 8"; matches_tagged_version = !HasNonfatalFailure(); diff --git a/src/test/unit/mcmc/hmc/mock_hmc.hpp b/src/test/unit/mcmc/hmc/mock_hmc.hpp index 33c9d29028c..0a79bc9233d 100644 --- a/src/test/unit/mcmc/hmc/mock_hmc.hpp +++ b/src/test/unit/mcmc/hmc/mock_hmc.hpp @@ -64,7 +64,7 @@ namespace stan { // Ensures that NUTS non-termination criterion is always true Eigen::VectorXd dtau_dp(ps_point& z) { - return Eigen::VectorXd::Ones(this->model_.num_params_r()); + return z.q; } Eigen::VectorXd dphi_dq(ps_point& z, diff --git a/src/test/unit/mcmc/hmc/nuts/base_nuts_test.cpp b/src/test/unit/mcmc/hmc/nuts/base_nuts_test.cpp index de309c11460..798f8e347e8 100644 --- a/src/test/unit/mcmc/hmc/nuts/base_nuts_test.cpp +++ b/src/test/unit/mcmc/hmc/nuts/base_nuts_test.cpp @@ -20,13 +20,18 @@ namespace stan { mock_nuts(const mock_model &m, rng_t& rng) : base_nuts(m, rng) { } + + bool compute_criterion(Eigen::VectorXd& p_sharp_minus, + Eigen::VectorXd& p_sharp_plus, + Eigen::VectorXd& rho) { + return true; + } }; class rho_inspector_mock_nuts: public base_nuts { - public: std::vector rho_values; rho_inspector_mock_nuts(const mock_model &m, rng_t& rng) @@ -41,6 +46,26 @@ namespace stan { } }; + class edge_inspector_mock_nuts: public base_nuts { + public: + std::vector p_sharp_minus_values; + std::vector p_sharp_plus_values; + edge_inspector_mock_nuts(const mock_model &m, rng_t& rng) + : base_nuts(m, rng) + { } + + bool compute_criterion(Eigen::VectorXd& p_sharp_minus, + Eigen::VectorXd& p_sharp_plus, + Eigen::VectorXd& rho) { + p_sharp_minus_values.push_back(p_sharp_minus(0)); + p_sharp_plus_values.push_back(p_sharp_plus(0)); + return true; + } + }; + // Mock Hamiltonian template class divergent_hamiltonian @@ -149,9 +174,12 @@ TEST(McmcNutsBaseNuts, build_tree_test) { stan::mcmc::ps_point z_propose(model_size); - Eigen::VectorXd p_sharp_left = Eigen::VectorXd::Zero(model_size); - Eigen::VectorXd p_sharp_right = Eigen::VectorXd::Zero(model_size); + Eigen::VectorXd p_begin = Eigen::VectorXd::Zero(model_size); + Eigen::VectorXd p_sharp_begin = Eigen::VectorXd::Zero(model_size); + Eigen::VectorXd p_end = Eigen::VectorXd::Zero(model_size); + Eigen::VectorXd p_sharp_end = Eigen::VectorXd::Zero(model_size); Eigen::VectorXd rho = z_init.p; + double log_sum_weight = -std::numeric_limits::infinity(); double H0 = -0.1; @@ -170,15 +198,18 @@ TEST(McmcNutsBaseNuts, build_tree_test) { stan::callbacks::stream_logger logger(debug, info, warn, error, fatal); bool valid_subtree = sampler.build_tree(3, z_propose, - p_sharp_left, p_sharp_right, rho, + p_sharp_begin, p_sharp_end, + rho, p_begin, p_end, H0, 1, n_leapfrog, log_sum_weight, sum_metro_prob, logger); EXPECT_TRUE(valid_subtree); EXPECT_EQ(init_momentum * (n_leapfrog + 1), rho(0)); - EXPECT_EQ(1, p_sharp_left(0)); - EXPECT_EQ(1, p_sharp_right(0)); + EXPECT_EQ(1.5, p_begin(0)); + EXPECT_EQ(1.5, p_sharp_begin(0)); + EXPECT_EQ(1.5, p_end(0)); + EXPECT_EQ(12, p_sharp_end(0)); EXPECT_EQ(8 * init_momentum, sampler.z().q(0)); EXPECT_EQ(init_momentum, sampler.z().p(0)); @@ -207,9 +238,12 @@ TEST(McmcNutsBaseNuts, rho_aggregation_test) { stan::mcmc::ps_point z_propose(model_size); - Eigen::VectorXd p_sharp_left = Eigen::VectorXd::Zero(model_size); - Eigen::VectorXd p_sharp_right = Eigen::VectorXd::Zero(model_size); + Eigen::VectorXd p_begin = Eigen::VectorXd::Zero(model_size); + Eigen::VectorXd p_sharp_begin = Eigen::VectorXd::Zero(model_size); + Eigen::VectorXd p_end = Eigen::VectorXd::Zero(model_size); + Eigen::VectorXd p_sharp_end = Eigen::VectorXd::Zero(model_size); Eigen::VectorXd rho = z_init.p; + double log_sum_weight = -std::numeric_limits::infinity(); double H0 = -0.1; @@ -228,18 +262,39 @@ TEST(McmcNutsBaseNuts, rho_aggregation_test) { stan::callbacks::stream_logger logger(debug, info, warn, error, fatal); sampler.build_tree(3, z_propose, - p_sharp_left, p_sharp_right, rho, + p_sharp_begin, p_sharp_end, + rho, p_begin, p_end, H0, 1, n_leapfrog, log_sum_weight, sum_metro_prob, logger); - EXPECT_EQ(7, sampler.rho_values.size()); - EXPECT_EQ(2 * init_momentum, sampler.rho_values.at(0)); - EXPECT_EQ(2 * init_momentum, sampler.rho_values.at(1)); - EXPECT_EQ(4 * init_momentum, sampler.rho_values.at(2)); - EXPECT_EQ(2 * init_momentum, sampler.rho_values.at(3)); - EXPECT_EQ(2 * init_momentum, sampler.rho_values.at(4)); - EXPECT_EQ(4 * init_momentum, sampler.rho_values.at(5)); - EXPECT_EQ(8 * init_momentum, sampler.rho_values.at(6)); + EXPECT_EQ(7 * 3, sampler.rho_values.size()); + + // Trajectory component spanning rhos + EXPECT_EQ(2 * init_momentum, sampler.rho_values[0]); + EXPECT_EQ(2 * init_momentum, sampler.rho_values[3]); + EXPECT_EQ(4 * init_momentum, sampler.rho_values[6]); + EXPECT_EQ(2 * init_momentum, sampler.rho_values[9]); + EXPECT_EQ(2 * init_momentum, sampler.rho_values[12]); + EXPECT_EQ(4 * init_momentum, sampler.rho_values[15]); + EXPECT_EQ(8 * init_momentum, sampler.rho_values[18]); + + // Cross trajectory component rhos + EXPECT_EQ(2 * init_momentum, sampler.rho_values[1]); + EXPECT_EQ(2 * init_momentum, sampler.rho_values[4]); + EXPECT_EQ(3 * init_momentum, sampler.rho_values[7]); + EXPECT_EQ(2 * init_momentum, sampler.rho_values[10]); + EXPECT_EQ(2 * init_momentum, sampler.rho_values[13]); + EXPECT_EQ(3 * init_momentum, sampler.rho_values[16]); + EXPECT_EQ(5 * init_momentum, sampler.rho_values[19]); + + EXPECT_EQ(2 * init_momentum, sampler.rho_values[2]); + EXPECT_EQ(2 * init_momentum, sampler.rho_values[5]); + EXPECT_EQ(3 * init_momentum, sampler.rho_values[8]); + EXPECT_EQ(2 * init_momentum, sampler.rho_values[11]); + EXPECT_EQ(2 * init_momentum, sampler.rho_values[14]); + EXPECT_EQ(3 * init_momentum, sampler.rho_values[17]); + EXPECT_EQ(5 * init_momentum, sampler.rho_values[20]); + } TEST(McmcNutsBaseNuts, divergence_test) { @@ -255,9 +310,12 @@ TEST(McmcNutsBaseNuts, divergence_test) { stan::mcmc::ps_point z_propose(model_size); - Eigen::VectorXd p_sharp_left = Eigen::VectorXd::Zero(model_size); - Eigen::VectorXd p_sharp_right = Eigen::VectorXd::Zero(model_size); + Eigen::VectorXd p_begin = Eigen::VectorXd::Zero(model_size); + Eigen::VectorXd p_sharp_begin = Eigen::VectorXd::Zero(model_size); + Eigen::VectorXd p_end = Eigen::VectorXd::Zero(model_size); + Eigen::VectorXd p_sharp_end = Eigen::VectorXd::Zero(model_size); Eigen::VectorXd rho = z_init.p; + double log_sum_weight = -std::numeric_limits::infinity(); double H0 = -0.1; @@ -279,7 +337,8 @@ TEST(McmcNutsBaseNuts, divergence_test) { sampler.z().V = -750; valid_subtree = sampler.build_tree(0, z_propose, - p_sharp_left, p_sharp_right, rho, + p_sharp_begin, p_sharp_end, + rho, p_begin, p_end, H0, 1, n_leapfrog, log_sum_weight, sum_metro_prob, logger); @@ -288,7 +347,8 @@ TEST(McmcNutsBaseNuts, divergence_test) { sampler.z().V = -250; valid_subtree = sampler.build_tree(0, z_propose, - p_sharp_left, p_sharp_right, rho, + p_sharp_begin, p_sharp_end, + rho, p_begin, p_end, H0, 1, n_leapfrog, log_sum_weight, sum_metro_prob, logger); @@ -298,7 +358,8 @@ TEST(McmcNutsBaseNuts, divergence_test) { sampler.z().V = 750; valid_subtree = sampler.build_tree(0, z_propose, - p_sharp_left, p_sharp_right, rho, + p_sharp_begin, p_sharp_end, + rho, p_begin, p_end, H0, 1, n_leapfrog, log_sum_weight, sum_metro_prob, logger); @@ -353,3 +414,70 @@ TEST(McmcNutsBaseNuts, transition) { EXPECT_EQ("", error.str()); EXPECT_EQ("", fatal.str()); } + +TEST(McmcNutsBaseNuts, transition_egde_momenta) { + + rng_t base_rng(0); + + int model_size = 1; + double init_momentum = 1.5; + + stan::mcmc::ps_point z_init(model_size); + z_init.q(0) = 0; + z_init.p(0) = init_momentum; + + stan::mcmc::mock_model model(model_size); + stan::mcmc::edge_inspector_mock_nuts sampler(model, base_rng); + + sampler.set_max_depth(2); + + sampler.set_nominal_stepsize(1); + sampler.set_stepsize_jitter(0); + sampler.sample_stepsize(); + sampler.z() = z_init; + + std::stringstream debug, info, warn, error, fatal; + stan::callbacks::stream_logger logger(debug, info, warn, error, fatal); + + stan::mcmc::sample init_sample(z_init.q, 0, 0); + + // Transition will expand trajectory until max_depth is hit + stan::mcmc::sample s = sampler.transition(init_sample, logger); + + EXPECT_EQ(2, sampler.depth_); + EXPECT_EQ((2 << (sampler.get_max_depth() - 1)) - 1, sampler.n_leapfrog_); + EXPECT_FALSE(sampler.divergent_); + + + EXPECT_EQ(9, sampler.p_sharp_minus_values.size()); + + // Depth 0 Transition Check + EXPECT_EQ(0, sampler.p_sharp_minus_values[0]); + EXPECT_EQ(init_momentum, sampler.p_sharp_plus_values[0]); + + EXPECT_EQ(0, sampler.p_sharp_minus_values[1]); + EXPECT_EQ(init_momentum, sampler.p_sharp_plus_values[1]); + + EXPECT_EQ(0, sampler.p_sharp_minus_values[2]); + EXPECT_EQ(init_momentum, sampler.p_sharp_plus_values[2]); + + // Depth 1 Build Tree Check + EXPECT_EQ(2 * init_momentum, sampler.p_sharp_minus_values[3]); + EXPECT_EQ(3 * init_momentum, sampler.p_sharp_plus_values[3]); + + EXPECT_EQ(2 * init_momentum, sampler.p_sharp_minus_values[4]); + EXPECT_EQ(3 * init_momentum, sampler.p_sharp_plus_values[4]); + + EXPECT_EQ(2 * init_momentum, sampler.p_sharp_minus_values[5]); + EXPECT_EQ(3 * init_momentum, sampler.p_sharp_plus_values[5]); + + // Depth 1 Transition Check + EXPECT_EQ(0, sampler.p_sharp_minus_values[6]); + EXPECT_EQ(3 * init_momentum, sampler.p_sharp_plus_values[6]); + + EXPECT_EQ(0, sampler.p_sharp_minus_values[7]); + EXPECT_EQ(2 * init_momentum, sampler.p_sharp_plus_values[7]); + + EXPECT_EQ(init_momentum, sampler.p_sharp_minus_values[8]); + EXPECT_EQ(3 * init_momentum, sampler.p_sharp_plus_values[8]); +} diff --git a/src/test/unit/mcmc/hmc/nuts/softabs_nuts_test.cpp b/src/test/unit/mcmc/hmc/nuts/softabs_nuts_test.cpp index 94c09614c70..1e07a2bdd40 100644 --- a/src/test/unit/mcmc/hmc/nuts/softabs_nuts_test.cpp +++ b/src/test/unit/mcmc/hmc/nuts/softabs_nuts_test.cpp @@ -38,9 +38,12 @@ TEST(McmcSoftAbsNuts, build_tree_test) { stan::mcmc::ps_point z_propose = z_init; - Eigen::VectorXd p_sharp_left = Eigen::VectorXd::Zero(z_init.p.size()); - Eigen::VectorXd p_sharp_right = Eigen::VectorXd::Zero(z_init.p.size()); + Eigen::VectorXd p_begin = Eigen::VectorXd::Zero(3); + Eigen::VectorXd p_sharp_begin = Eigen::VectorXd::Zero(3); + Eigen::VectorXd p_end = Eigen::VectorXd::Zero(3); + Eigen::VectorXd p_sharp_end = Eigen::VectorXd::Zero(3); Eigen::VectorXd rho = z_init.p; + double log_sum_weight = -std::numeric_limits::infinity(); double H0 = -0.1; @@ -48,7 +51,8 @@ TEST(McmcSoftAbsNuts, build_tree_test) { double sum_metro_prob = 0; bool valid_subtree = sampler.build_tree(3, z_propose, - p_sharp_left, p_sharp_right, rho, + p_sharp_begin, p_sharp_end, + rho, p_begin, p_end, H0, 1, n_leapfrog, log_sum_weight, sum_metro_prob, logger); @@ -60,6 +64,22 @@ TEST(McmcSoftAbsNuts, build_tree_test) { EXPECT_FLOAT_EQ(11.679803, rho(1)); EXPECT_FLOAT_EQ(-11.679803, rho(2)); + EXPECT_FLOAT_EQ(-1.0960016, p_begin(0)); + EXPECT_FLOAT_EQ(1.0960016, p_begin(1)); + EXPECT_FLOAT_EQ(-1.0960016, p_begin(2)); + + EXPECT_FLOAT_EQ(-0.83470845, p_sharp_begin(0)); + EXPECT_FLOAT_EQ(0.83470845, p_sharp_begin(1)); + EXPECT_FLOAT_EQ(-0.83470845, p_sharp_begin(2)); + + EXPECT_FLOAT_EQ(-1.5019561, p_end(0)); + EXPECT_FLOAT_EQ(1.5019561, p_end(1)); + EXPECT_FLOAT_EQ(-1.5019561, p_end(2)); + + EXPECT_FLOAT_EQ(-1.143881, p_sharp_end(0)); + EXPECT_FLOAT_EQ(1.143881, p_sharp_end(1)); + EXPECT_FLOAT_EQ(-1.143881, p_sharp_end(2)); + EXPECT_FLOAT_EQ(0.20423166, sampler.z().q(0)); EXPECT_FLOAT_EQ(-0.20423166, sampler.z().q(1)); EXPECT_FLOAT_EQ(0.20423166, sampler.z().q(2)); @@ -110,38 +130,46 @@ TEST(McmcSoftAbsNuts, tree_boundary_test) { metric.init(z_test, logger); softabs_integrator.evolve(z_test, metric, epsilon, logger); + Eigen::VectorXd p_forward_1 = z_test.p; Eigen::VectorXd p_sharp_forward_1 = metric.dtau_dp(z_test); softabs_integrator.evolve(z_test, metric, epsilon, logger); + Eigen::VectorXd p_forward_2 = z_test.p; Eigen::VectorXd p_sharp_forward_2 = metric.dtau_dp(z_test); softabs_integrator.evolve(z_test, metric, epsilon, logger); softabs_integrator.evolve(z_test, metric, epsilon, logger); + Eigen::VectorXd p_forward_3 = z_test.p; Eigen::VectorXd p_sharp_forward_3 = metric.dtau_dp(z_test); softabs_integrator.evolve(z_test, metric, epsilon, logger); softabs_integrator.evolve(z_test, metric, epsilon, logger); softabs_integrator.evolve(z_test, metric, epsilon, logger); softabs_integrator.evolve(z_test, metric, epsilon, logger); + Eigen::VectorXd p_forward_4 = z_test.p; Eigen::VectorXd p_sharp_forward_4 = metric.dtau_dp(z_test); z_test = z_init; metric.init(z_test, logger); softabs_integrator.evolve(z_test, metric, -epsilon, logger); + Eigen::VectorXd p_backward_1 = z_test.p; Eigen::VectorXd p_sharp_backward_1 = metric.dtau_dp(z_test); softabs_integrator.evolve(z_test, metric, -epsilon, logger); + Eigen::VectorXd p_backward_2 = z_test.p; Eigen::VectorXd p_sharp_backward_2 = metric.dtau_dp(z_test); softabs_integrator.evolve(z_test, metric, -epsilon, logger); softabs_integrator.evolve(z_test, metric, -epsilon, logger); + Eigen::VectorXd p_backward_3 = z_test.p; Eigen::VectorXd p_sharp_backward_3 = metric.dtau_dp(z_test); softabs_integrator.evolve(z_test, metric, -epsilon, logger); softabs_integrator.evolve(z_test, metric, -epsilon, logger); softabs_integrator.evolve(z_test, metric, -epsilon, logger); softabs_integrator.evolve(z_test, metric, -epsilon, logger); + Eigen::VectorXd p_backward_4 = z_test.p; Eigen::VectorXd p_sharp_backward_4 = metric.dtau_dp(z_test); // Check expected tree boundaries to those dynamically geneated by NUTS @@ -153,9 +181,12 @@ TEST(McmcSoftAbsNuts, tree_boundary_test) { stan::mcmc::ps_point z_propose = z_init; - Eigen::VectorXd p_sharp_left = Eigen::VectorXd::Zero(z_init.p.size()); - Eigen::VectorXd p_sharp_right = Eigen::VectorXd::Zero(z_init.p.size()); + Eigen::VectorXd p_begin = Eigen::VectorXd::Zero(3); + Eigen::VectorXd p_sharp_begin = Eigen::VectorXd::Zero(3); + Eigen::VectorXd p_end = Eigen::VectorXd::Zero(3); + Eigen::VectorXd p_sharp_end = Eigen::VectorXd::Zero(3); Eigen::VectorXd rho = z_init.p; + double log_sum_weight = -std::numeric_limits::infinity(); double H0 = -0.1; @@ -166,121 +197,145 @@ TEST(McmcSoftAbsNuts, tree_boundary_test) { sampler.z() = z_init; sampler.init_hamiltonian(logger); sampler.build_tree(0, z_propose, - p_sharp_left, p_sharp_right, rho, + p_sharp_begin, p_sharp_end, + rho, p_begin, p_end, H0, 1, n_leapfrog, log_sum_weight, sum_metro_prob, logger); - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_forward_1(n), p_sharp_left(n)); - - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_forward_1(n), p_sharp_right(n)); + for (int n = 0; n < rho.size(); ++n) { + EXPECT_FLOAT_EQ(p_forward_1(n), p_begin(n)); + EXPECT_FLOAT_EQ(p_sharp_forward_1(n), p_sharp_begin(n)); + + EXPECT_FLOAT_EQ(p_forward_1(n), p_end(n)); + EXPECT_FLOAT_EQ(p_sharp_forward_1(n), p_sharp_end(n)); + } // Depth 1 forward sampler.z() = z_init; sampler.init_hamiltonian(logger); sampler.build_tree(1, z_propose, - p_sharp_left, p_sharp_right, rho, + p_sharp_begin, p_sharp_end, + rho, p_begin, p_end, H0, 1, n_leapfrog, log_sum_weight, sum_metro_prob, logger); - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_forward_1(n), p_sharp_left(n)); - - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_forward_2(n), p_sharp_right(n)); + for (int n = 0; n < rho.size(); ++n) { + EXPECT_FLOAT_EQ(p_forward_1(n), p_begin(n)); + EXPECT_FLOAT_EQ(p_sharp_forward_1(n), p_sharp_begin(n)); + + EXPECT_FLOAT_EQ(p_forward_2(n), p_end(n)); + EXPECT_FLOAT_EQ(p_sharp_forward_2(n), p_sharp_end(n)); + } // Depth 2 forward sampler.z() = z_init; sampler.init_hamiltonian(logger); sampler.build_tree(2, z_propose, - p_sharp_left, p_sharp_right, rho, + p_sharp_begin, p_sharp_end, + rho, p_begin, p_end, H0, 1, n_leapfrog, log_sum_weight, sum_metro_prob, logger); - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_forward_1(n), p_sharp_left(n)); - - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_forward_3(n), p_sharp_right(n)); - + for (int n = 0; n < rho.size(); ++n) { + EXPECT_FLOAT_EQ(p_forward_1(n), p_begin(n)); + EXPECT_FLOAT_EQ(p_sharp_forward_1(n), p_sharp_begin(n)); + + EXPECT_FLOAT_EQ(p_forward_3(n), p_end(n)); + EXPECT_FLOAT_EQ(p_sharp_forward_3(n), p_sharp_end(n)); + } + // Depth 3 forward sampler.z() = z_init; sampler.init_hamiltonian(logger); sampler.build_tree(3, z_propose, - p_sharp_left, p_sharp_right, rho, + p_sharp_begin, p_sharp_end, + rho, p_begin, p_end, H0, 1, n_leapfrog, log_sum_weight, sum_metro_prob, logger); - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_forward_1(n), p_sharp_left(n)); + for (int n = 0; n < rho.size(); ++n) { + EXPECT_FLOAT_EQ(p_forward_1(n), p_begin(n)); + EXPECT_FLOAT_EQ(p_sharp_forward_1(n), p_sharp_begin(n)); - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_forward_4(n), p_sharp_right(n)); + EXPECT_FLOAT_EQ(p_forward_4(n), p_end(n)); + EXPECT_FLOAT_EQ(p_sharp_forward_4(n), p_sharp_end(n)); + } // Depth 0 backward sampler.z() = z_init; sampler.init_hamiltonian(logger); sampler.build_tree(0, z_propose, - p_sharp_left, p_sharp_right, rho, + p_sharp_begin, p_sharp_end, + rho, p_begin, p_end, H0, -1, n_leapfrog, log_sum_weight, sum_metro_prob, logger); - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_backward_1(n), p_sharp_left(n)); - - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_backward_1(n), p_sharp_right(n)); + for (int n = 0; n < rho.size(); ++n) { + EXPECT_FLOAT_EQ(p_backward_1(n), p_begin(n)); + EXPECT_FLOAT_EQ(p_sharp_backward_1(n), p_sharp_begin(n)); + + EXPECT_FLOAT_EQ(p_backward_1(n), p_end(n)); + EXPECT_FLOAT_EQ(p_sharp_backward_1(n), p_sharp_end(n)); + } // Depth 1 backward sampler.z() = z_init; sampler.init_hamiltonian(logger); sampler.build_tree(1, z_propose, - p_sharp_left, p_sharp_right, rho, + p_sharp_begin, p_sharp_end, + rho, p_begin, p_end, H0, -1, n_leapfrog, log_sum_weight, sum_metro_prob, logger); - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_backward_1(n), p_sharp_left(n)); - - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_backward_2(n), p_sharp_right(n)); + for (int n = 0; n < rho.size(); ++n) { + EXPECT_FLOAT_EQ(p_backward_1(n), p_begin(n)); + EXPECT_FLOAT_EQ(p_sharp_backward_1(n), p_sharp_begin(n)); + + EXPECT_FLOAT_EQ(p_backward_2(n), p_end(n)); + EXPECT_FLOAT_EQ(p_sharp_backward_2(n), p_sharp_end(n)); + } // Depth 2 backward sampler.z() = z_init; sampler.init_hamiltonian(logger); sampler.build_tree(2, z_propose, - p_sharp_left, p_sharp_right, rho, + p_sharp_begin, p_sharp_end, + rho, p_begin, p_end, H0, -1, n_leapfrog, log_sum_weight, sum_metro_prob, logger); - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_backward_1(n), p_sharp_left(n)); - - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_backward_3(n), p_sharp_right(n)); + for (int n = 0; n < rho.size(); ++n) { + EXPECT_FLOAT_EQ(p_backward_1(n), p_begin(n)); + EXPECT_FLOAT_EQ(p_sharp_backward_1(n), p_sharp_begin(n)); + + EXPECT_FLOAT_EQ(p_backward_3(n), p_end(n)); + EXPECT_FLOAT_EQ(p_sharp_backward_3(n), p_sharp_end(n)); + } // Depth 3 backward sampler.z() = z_init; sampler.init_hamiltonian(logger); sampler.build_tree(3, z_propose, - p_sharp_left, p_sharp_right, rho, + p_sharp_begin, p_sharp_end, + rho, p_begin, p_end, H0, -1, n_leapfrog, log_sum_weight, sum_metro_prob, logger); - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_backward_1(n), p_sharp_left(n)); - - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_backward_4(n), p_sharp_right(n)); + for (int n = 0; n < rho.size(); ++n) { + EXPECT_FLOAT_EQ(p_backward_1(n), p_begin(n)); + EXPECT_FLOAT_EQ(p_sharp_backward_1(n), p_sharp_begin(n)); + + EXPECT_FLOAT_EQ(p_backward_4(n), p_end(n)); + EXPECT_FLOAT_EQ(p_sharp_backward_4(n), p_sharp_end(n)); + } } TEST(McmcSoftAbsNuts, transition_test) { diff --git a/src/test/unit/mcmc/hmc/nuts/unit_e_nuts_test.cpp b/src/test/unit/mcmc/hmc/nuts/unit_e_nuts_test.cpp index e34a6d69bca..959f26d92ca 100644 --- a/src/test/unit/mcmc/hmc/nuts/unit_e_nuts_test.cpp +++ b/src/test/unit/mcmc/hmc/nuts/unit_e_nuts_test.cpp @@ -38,9 +38,12 @@ TEST(McmcUnitENuts, build_tree_test) { stan::mcmc::ps_point z_propose = z_init; - Eigen::VectorXd p_sharp_left = Eigen::VectorXd::Zero(z_init.p.size()); - Eigen::VectorXd p_sharp_right = Eigen::VectorXd::Zero(z_init.p.size()); + Eigen::VectorXd p_begin = Eigen::VectorXd::Zero(z_init.p.size()); + Eigen::VectorXd p_sharp_begin = Eigen::VectorXd::Zero(z_init.p.size()); + Eigen::VectorXd p_end = Eigen::VectorXd::Zero(z_init.p.size()); + Eigen::VectorXd p_sharp_end = Eigen::VectorXd::Zero(z_init.p.size()); Eigen::VectorXd rho = z_init.p; + double log_sum_weight = -std::numeric_limits::infinity(); double H0 = -0.1; @@ -48,7 +51,8 @@ TEST(McmcUnitENuts, build_tree_test) { double sum_metro_prob = 0; bool valid_subtree = sampler.build_tree(3, z_propose, - p_sharp_left, p_sharp_right, rho, + p_sharp_begin, p_sharp_end, + rho, p_begin, p_end, H0, 1, n_leapfrog, log_sum_weight, sum_metro_prob, logger); @@ -60,6 +64,22 @@ TEST(McmcUnitENuts, build_tree_test) { EXPECT_FLOAT_EQ(-11.401228, rho(0)); EXPECT_FLOAT_EQ(11.401228, rho(1)); EXPECT_FLOAT_EQ(-11.401228, rho(2)); + + EXPECT_FLOAT_EQ(-1.09475, p_begin(0)); + EXPECT_FLOAT_EQ(1.09475, p_begin(1)); + EXPECT_FLOAT_EQ(-1.09475, p_begin(2)); + + EXPECT_FLOAT_EQ(-1.09475, p_sharp_begin(0)); + EXPECT_FLOAT_EQ(1.09475, p_sharp_begin(1)); + EXPECT_FLOAT_EQ(-1.09475, p_sharp_begin(2)); + + EXPECT_FLOAT_EQ(-1.4131583, p_end(0)); + EXPECT_FLOAT_EQ(1.4131583, p_end(1)); + EXPECT_FLOAT_EQ(-1.4131583, p_end(2)); + + EXPECT_FLOAT_EQ(-1.4131583, p_sharp_end(0)); + EXPECT_FLOAT_EQ(1.4131583, p_sharp_end(1)); + EXPECT_FLOAT_EQ(-1.4131583, p_sharp_end(2)); EXPECT_FLOAT_EQ(-0.022019938, sampler.z().q(0)); EXPECT_FLOAT_EQ(0.022019938, sampler.z().q(1)); @@ -111,38 +131,46 @@ TEST(McmcUnitENuts, tree_boundary_test) { metric.init(z_test, logger); unit_e_integrator.evolve(z_test, metric, epsilon, logger); + Eigen::VectorXd p_forward_1 = z_test.p; Eigen::VectorXd p_sharp_forward_1 = metric.dtau_dp(z_test); unit_e_integrator.evolve(z_test, metric, epsilon, logger); + Eigen::VectorXd p_forward_2 = z_test.p; Eigen::VectorXd p_sharp_forward_2 = metric.dtau_dp(z_test); unit_e_integrator.evolve(z_test, metric, epsilon, logger); unit_e_integrator.evolve(z_test, metric, epsilon, logger); + Eigen::VectorXd p_forward_3 = z_test.p; Eigen::VectorXd p_sharp_forward_3 = metric.dtau_dp(z_test); unit_e_integrator.evolve(z_test, metric, epsilon, logger); unit_e_integrator.evolve(z_test, metric, epsilon, logger); unit_e_integrator.evolve(z_test, metric, epsilon, logger); unit_e_integrator.evolve(z_test, metric, epsilon, logger); + Eigen::VectorXd p_forward_4 = z_test.p; Eigen::VectorXd p_sharp_forward_4 = metric.dtau_dp(z_test); z_test = z_init; metric.init(z_test, logger); unit_e_integrator.evolve(z_test, metric, -epsilon, logger); + Eigen::VectorXd p_backward_1 = z_test.p; Eigen::VectorXd p_sharp_backward_1 = metric.dtau_dp(z_test); unit_e_integrator.evolve(z_test, metric, -epsilon, logger); + Eigen::VectorXd p_backward_2 = z_test.p; Eigen::VectorXd p_sharp_backward_2 = metric.dtau_dp(z_test); unit_e_integrator.evolve(z_test, metric, -epsilon, logger); unit_e_integrator.evolve(z_test, metric, -epsilon, logger); + Eigen::VectorXd p_backward_3 = z_test.p; Eigen::VectorXd p_sharp_backward_3 = metric.dtau_dp(z_test); unit_e_integrator.evolve(z_test, metric, -epsilon, logger); unit_e_integrator.evolve(z_test, metric, -epsilon, logger); unit_e_integrator.evolve(z_test, metric, -epsilon, logger); unit_e_integrator.evolve(z_test, metric, -epsilon, logger); + Eigen::VectorXd p_backward_4 = z_test.p; Eigen::VectorXd p_sharp_backward_4 = metric.dtau_dp(z_test); // Check expected tree boundaries to those dynamically geneated by NUTS @@ -154,9 +182,12 @@ TEST(McmcUnitENuts, tree_boundary_test) { stan::mcmc::ps_point z_propose = z_init; - Eigen::VectorXd p_sharp_left = Eigen::VectorXd::Zero(z_init.p.size()); - Eigen::VectorXd p_sharp_right = Eigen::VectorXd::Zero(z_init.p.size()); + Eigen::VectorXd p_begin = Eigen::VectorXd::Zero(z_init.p.size()); + Eigen::VectorXd p_sharp_begin = Eigen::VectorXd::Zero(z_init.p.size()); + Eigen::VectorXd p_end = Eigen::VectorXd::Zero(z_init.p.size()); + Eigen::VectorXd p_sharp_end = Eigen::VectorXd::Zero(z_init.p.size()); Eigen::VectorXd rho = z_init.p; + double log_sum_weight = -std::numeric_limits::infinity(); double H0 = -0.1; @@ -167,121 +198,146 @@ TEST(McmcUnitENuts, tree_boundary_test) { sampler.z() = z_init; sampler.init_hamiltonian(logger); sampler.build_tree(0, z_propose, - p_sharp_left, p_sharp_right, rho, + p_sharp_begin, p_sharp_end, + rho, p_begin, p_end, H0, 1, n_leapfrog, log_sum_weight, sum_metro_prob, logger); - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_forward_1(n), p_sharp_left(n)); - - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_forward_1(n), p_sharp_right(n)); - + for (int n = 0; n < rho.size(); ++n) { + EXPECT_FLOAT_EQ(p_forward_1(n), p_begin(n)); + EXPECT_FLOAT_EQ(p_sharp_forward_1(n), p_sharp_begin(n)); + + EXPECT_FLOAT_EQ(p_forward_1(n), p_end(n)); + EXPECT_FLOAT_EQ(p_sharp_forward_1(n), p_sharp_end(n)); + } + // Depth 1 forward sampler.z() = z_init; sampler.init_hamiltonian(logger); sampler.build_tree(1, z_propose, - p_sharp_left, p_sharp_right, rho, + p_sharp_begin, p_sharp_end, + rho, p_begin, p_end, H0, 1, n_leapfrog, log_sum_weight, sum_metro_prob, logger); - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_forward_1(n), p_sharp_left(n)); - - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_forward_2(n), p_sharp_right(n)); + for (int n = 0; n < rho.size(); ++n) { + EXPECT_FLOAT_EQ(p_forward_1(n), p_begin(n)); + EXPECT_FLOAT_EQ(p_sharp_forward_1(n), p_sharp_begin(n)); + + EXPECT_FLOAT_EQ(p_forward_2(n), p_end(n)); + EXPECT_FLOAT_EQ(p_sharp_forward_2(n), p_sharp_end(n)); + } // Depth 2 forward sampler.z() = z_init; sampler.init_hamiltonian(logger); sampler.build_tree(2, z_propose, - p_sharp_left, p_sharp_right, rho, + p_sharp_begin, p_sharp_end, + rho, p_begin, p_end, H0, 1, n_leapfrog, log_sum_weight, sum_metro_prob, logger); - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_forward_1(n), p_sharp_left(n)); - - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_forward_3(n), p_sharp_right(n)); + for (int n = 0; n < rho.size(); ++n) { + EXPECT_FLOAT_EQ(p_forward_1(n), p_begin(n)); + EXPECT_FLOAT_EQ(p_sharp_forward_1(n), p_sharp_begin(n)); + + EXPECT_FLOAT_EQ(p_forward_3(n), p_end(n)); + EXPECT_FLOAT_EQ(p_sharp_forward_3(n), p_sharp_end(n)); + } // Depth 3 forward sampler.z() = z_init; sampler.init_hamiltonian(logger); sampler.build_tree(3, z_propose, - p_sharp_left, p_sharp_right, rho, + p_sharp_begin, p_sharp_end, + rho, p_begin, p_end, H0, 1, n_leapfrog, log_sum_weight, sum_metro_prob, logger); - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_forward_1(n), p_sharp_left(n)); - - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_forward_4(n), p_sharp_right(n)); + for (int n = 0; n < rho.size(); ++n) { + EXPECT_FLOAT_EQ(p_forward_1(n), p_begin(n)); + EXPECT_FLOAT_EQ(p_sharp_forward_1(n), p_sharp_begin(n)); + + EXPECT_FLOAT_EQ(p_forward_4(n), p_end(n)); + EXPECT_FLOAT_EQ(p_sharp_forward_4(n), p_sharp_end(n)); + } // Depth 0 backward sampler.z() = z_init; sampler.init_hamiltonian(logger); sampler.build_tree(0, z_propose, - p_sharp_left, p_sharp_right, rho, + p_sharp_begin, p_sharp_end, + rho, p_begin, p_end, H0, -1, n_leapfrog, log_sum_weight, sum_metro_prob, logger); - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_backward_1(n), p_sharp_left(n)); - - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_backward_1(n), p_sharp_right(n)); + for (int n = 0; n < rho.size(); ++n) { + EXPECT_FLOAT_EQ(p_backward_1(n), p_begin(n)); + EXPECT_FLOAT_EQ(p_sharp_backward_1(n), p_sharp_begin(n)); + + EXPECT_FLOAT_EQ(p_backward_1(n), p_end(n)); + EXPECT_FLOAT_EQ(p_sharp_backward_1(n), p_sharp_end(n)); + } // Depth 1 backward sampler.z() = z_init; sampler.init_hamiltonian(logger); sampler.build_tree(1, z_propose, - p_sharp_left, p_sharp_right, rho, + p_sharp_begin, p_sharp_end, + rho, p_begin, p_end, H0, -1, n_leapfrog, log_sum_weight, sum_metro_prob, logger); - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_backward_1(n), p_sharp_left(n)); - - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_backward_2(n), p_sharp_right(n)); + for (int n = 0; n < rho.size(); ++n) { + EXPECT_FLOAT_EQ(p_backward_1(n), p_begin(n)); + EXPECT_FLOAT_EQ(p_sharp_backward_1(n), p_sharp_begin(n)); + + EXPECT_FLOAT_EQ(p_backward_2(n), p_end(n)); + EXPECT_FLOAT_EQ(p_sharp_backward_2(n), p_sharp_end(n)); + } // Depth 2 backward sampler.z() = z_init; sampler.init_hamiltonian(logger); sampler.build_tree(2, z_propose, - p_sharp_left, p_sharp_right, rho, + p_sharp_begin, p_sharp_end, + rho, p_begin, p_end, H0, -1, n_leapfrog, log_sum_weight, sum_metro_prob, logger); - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_backward_1(n), p_sharp_left(n)); - - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_backward_3(n), p_sharp_right(n)); + for (int n = 0; n < rho.size(); ++n) { + EXPECT_FLOAT_EQ(p_backward_1(n), p_begin(n)); + EXPECT_FLOAT_EQ(p_sharp_backward_1(n), p_sharp_begin(n)); + + EXPECT_FLOAT_EQ(p_backward_3(n), p_end(n)); + EXPECT_FLOAT_EQ(p_sharp_backward_3(n), p_sharp_end(n)); + } // Depth 3 backward sampler.z() = z_init; sampler.init_hamiltonian(logger); sampler.build_tree(3, z_propose, - p_sharp_left, p_sharp_right, rho, + p_sharp_begin, p_sharp_end, + rho, p_begin, p_end, H0, -1, n_leapfrog, log_sum_weight, sum_metro_prob, logger); - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_backward_1(n), p_sharp_left(n)); + for (int n = 0; n < rho.size(); ++n) { + EXPECT_FLOAT_EQ(p_backward_1(n), p_begin(n)); + EXPECT_FLOAT_EQ(p_sharp_backward_1(n), p_sharp_begin(n)); + + EXPECT_FLOAT_EQ(p_backward_4(n), p_end(n)); + EXPECT_FLOAT_EQ(p_sharp_backward_4(n), p_sharp_end(n)); + } - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_backward_4(n), p_sharp_right(n)); } TEST(McmcUnitENuts, transition_test) {