From aa0c811be1c5669b20f72207ba03ccbcbc81b416 Mon Sep 17 00:00:00 2001 From: betanalpha Date: Sun, 11 Aug 2019 21:45:07 -0400 Subject: [PATCH 01/12] Expanded termination criterion --- src/stan/mcmc/hmc/nuts/base_nuts.hpp | 100 +++++++++++++++++++++------ 1 file changed, 80 insertions(+), 20 deletions(-) diff --git a/src/stan/mcmc/hmc/nuts/base_nuts.hpp b/src/stan/mcmc/hmc/nuts/base_nuts.hpp index 04444cd8518..761645ca740 100644 --- a/src/stan/mcmc/hmc/nuts/base_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/base_nuts.hpp @@ -87,10 +87,17 @@ namespace stan { ps_point z_sample(z_plus); ps_point z_propose(z_plus); - Eigen::VectorXd p_sharp_plus = this->hamiltonian_.dtau_dp(this->z_); - Eigen::VectorXd p_sharp_dummy = p_sharp_plus; - Eigen::VectorXd p_sharp_minus = p_sharp_plus; + Eigen::VectorXd p_sharp_plus_plus = this->hamiltonian_.dtau_dp(this->z_); + Eigen::VectorXd p_sharp_plus_minus = p_sharp_plus_plus; + Eigen::VectorXd p_sharp_minus_plus = p_sharp_plus_plus; + Eigen::VectorXd p_sharp_minus_minus = p_sharp_plus_plus; + Eigen::VectorXd rho = this->z_.p; + + Eigen::VectorXd p_plus_plus = this->z_.p; + Eigen::VectorXd p_plus_minus = this->z_.p; + Eigen::VectorXd p_minus_plus = this->z_.p; + Eigen::VectorXd p_minus_minus = this->z_.p; double log_sum_weight = 0; // log(exp(H0 - H0)) double H0 = this->hamiltonian_.H(this->z_); @@ -103,25 +110,32 @@ namespace stan { while (this->depth_ < this->max_depth_) { // Build a new subtree in a random direction - Eigen::VectorXd rho_subtree = Eigen::VectorXd::Zero(rho.size()); + Eigen::VectorXd rho_plus = Eigen::VectorXd::Zero(rho.size()); + Eigen::VectorXd rho_minus = Eigen::VectorXd::Zero(rho.size()); + bool valid_subtree = false; double log_sum_weight_subtree - = -std::numeric_limits::infinity(); + = -std::numeric_limits::infinity(); if (this->rand_uniform_() > 0.5) { this->z_.ps_point::operator=(z_plus); + rho_minus = rho; valid_subtree = build_tree(this->depth_, z_propose, - p_sharp_dummy, p_sharp_plus, rho_subtree, + p_sharp_plus_minus, p_sharp_plus_plus, + rho_plus, p_plus_minus, p_plus_plus, H0, 1, n_leapfrog, log_sum_weight_subtree, sum_metro_prob, logger); z_plus.ps_point::operator=(this->z_); + } else { this->z_.ps_point::operator=(z_minus); + rho_plus = rho; valid_subtree = build_tree(this->depth_, z_propose, - p_sharp_dummy, p_sharp_minus, rho_subtree, + p_sharp_minus_plus, p_sharp_minus_minus, + rho_minus, p_minus_plus, p_minus_minus, H0, -1, n_leapfrog, log_sum_weight_subtree, sum_metro_prob, logger); @@ -145,9 +159,30 @@ namespace stan { log_sum_weight = math::log_sum_exp(log_sum_weight, log_sum_weight_subtree); - // Break when NUTS criterion is no longer satisfied - rho += rho_subtree; - if (!compute_criterion(p_sharp_minus, p_sharp_plus, rho)) + // Break when no-u-turn criterion is no longer satisfied + rho = rho_minus + rho_plus; + + // Boundary check + bool persist_criterion = + compute_criterion(p_sharp_minus_minus, + p_sharp_plus_plus, + rho); + + // Extra internal checks + Eigen::VectorXd rho_extended = rho_minus + p_plus_minus; + + persist_criterion &= + compute_criterion(p_sharp_minus_minus, + p_sharp_plus_minus, + rho_extended); + + rho_extended = rho_plus + p_minus_plus; + persist_criterion &= + compute_criterion(p_sharp_minus_plus, + p_sharp_plus_plus, + rho_extended); + + if (!persist_criterion) break; } @@ -193,9 +228,11 @@ namespace stan { * * @param depth Depth of the desired subtree * @param z_propose State proposed from subtree - * @param p_sharp_left p_sharp from left boundary of returned tree - * @param p_sharp_right p_sharp from the right boundary of returned tree + * @param p_sharp_left p_sharp at the left boundary of returned tree + * @param p_sharp_right p_sharp at the right boundary of returned tree * @param rho Summed momentum across trajectory + * @param p_left p at the left boundary of returned tree + * @param p_right p at the right boundary of returned tree * @param H0 Hamiltonian of initial state * @param sign Direction in time to built subtree * @param n_leapfrog Summed number of leapfrog evaluations @@ -207,6 +244,8 @@ namespace stan { Eigen::VectorXd& p_sharp_left, Eigen::VectorXd& p_sharp_right, Eigen::VectorXd& rho, + Eigen::VectorXd& p_left, + Eigen::VectorXd& p_right, double H0, double sign, int& n_leapfrog, double& log_sum_weight, double& sum_metro_prob, callbacks::logger& logger) { @@ -231,23 +270,28 @@ namespace stan { sum_metro_prob += std::exp(H0 - h); z_propose = this->z_; - rho += this->z_.p; p_sharp_left = this->hamiltonian_.dtau_dp(this->z_); - p_sharp_right = p_sharp_left; + p_sharp_right = p_sharp_left; + + rho += this->z_.p; + p_left = this->z_.p; + p_right = p_left; return !this->divergent_; } // General recursion - Eigen::VectorXd p_sharp_dummy(this->z_.p.size()); // Build the left subtree double log_sum_weight_left = -std::numeric_limits::infinity(); + Eigen::VectorXd p_sharp_left_right(this->z_.p.size()); Eigen::VectorXd rho_left = Eigen::VectorXd::Zero(rho.size()); + Eigen::VectorXd p_left_right(this->z_.p.size()); bool valid_left = build_tree(depth - 1, z_propose, - p_sharp_left, p_sharp_dummy, rho_left, + p_sharp_left, p_sharp_left_right, + rho_left, p_left, p_left_right, H0, sign, n_leapfrog, log_sum_weight_left, sum_metro_prob, logger); @@ -258,11 +302,14 @@ namespace stan { ps_point z_propose_right(this->z_); double log_sum_weight_right = -std::numeric_limits::infinity(); + Eigen::VectorXd p_sharp_right_left(this->z_.p.size()); Eigen::VectorXd rho_right = Eigen::VectorXd::Zero(rho.size()); + Eigen::VectorXd p_right_left(this->z_.p.size()); bool valid_right = build_tree(depth - 1, z_propose_right, - p_sharp_dummy, p_sharp_right, rho_right, + p_sharp_right_left, p_sharp_right, + rho_right, p_right_left, p_right, H0, sign, n_leapfrog, log_sum_weight_right, sum_metro_prob, logger); @@ -283,11 +330,24 @@ namespace stan { if (this->rand_uniform_() < accept_prob) z_propose = z_propose_right; } - + Eigen::VectorXd rho_subtree = rho_left + rho_right; rho += rho_subtree; - - return compute_criterion(p_sharp_left, p_sharp_right, rho_subtree); + + // Boundary check + bool persist_criterion = + compute_criterion(p_sharp_left, p_sharp_right, rho_subtree); + + // Extra internal checks + rho_subtree = rho_left + p_right_left; + persist_criterion &= + compute_criterion(p_sharp_left, p_sharp_right_left, rho_subtree); + + rho_subtree = rho_right + p_left_right; + persist_criterion &= + compute_criterion(p_sharp_left_right, p_sharp_right, rho_subtree); + + return persist_criterion; } int depth_; From f2e6464b167910f5b4c077911d54e6d62e72c1d6 Mon Sep 17 00:00:00 2001 From: betanalpha Date: Mon, 12 Aug 2019 21:35:03 -0400 Subject: [PATCH 02/12] Add additional no-u-turn checks, clearer naming conventions --- src/stan/mcmc/hmc/nuts/base_nuts.hpp | 222 ++++++++++++++------------- 1 file changed, 119 insertions(+), 103 deletions(-) diff --git a/src/stan/mcmc/hmc/nuts/base_nuts.hpp b/src/stan/mcmc/hmc/nuts/base_nuts.hpp index 761645ca740..7c9e1768f73 100644 --- a/src/stan/mcmc/hmc/nuts/base_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/base_nuts.hpp @@ -81,70 +81,80 @@ namespace stan { this->hamiltonian_.sample_p(this->z_, this->rand_int_); this->hamiltonian_.init(this->z_, logger); - ps_point z_plus(this->z_); - ps_point z_minus(z_plus); + ps_point z_fwd(this->z_); // State at forward end of trajectory + ps_point z_bck(z_plus); // State at backward end of trajectory ps_point z_sample(z_plus); ps_point z_propose(z_plus); - Eigen::VectorXd p_sharp_plus_plus = this->hamiltonian_.dtau_dp(this->z_); - Eigen::VectorXd p_sharp_plus_minus = p_sharp_plus_plus; - Eigen::VectorXd p_sharp_minus_plus = p_sharp_plus_plus; - Eigen::VectorXd p_sharp_minus_minus = p_sharp_plus_plus; - + // Momentum and sharp momentum at forward end of forward subtree + Eigen::VectorXd p_fwd_fwd = this->z_.p; + Eigen::VectorXd p_sharp_fwd_fwd = this->hamiltonian_.dtau_dp(this->z_); + + // Momentum and sharp momentum at backward end of forward subtree + Eigen::VectorXd p_fwd_bkc = this->z_.p; + Eigen::VectorXd p_sharp_fwd_bck = p_sharp_fwd_fwd; + + // Momentum and sharp momentum at forward end of backward subtree + Eigen::VectorXd p_bck_fwd = this->z_.p; + Eigen::VectorXd p_sharp_bck_fwd = p_sharp_fwd_fwd; + + // Momentum and sharp momentum at backward end of backward subtree + Eigen::VectorXd p_bck_bck = this->z_.p; + Eigen::VectorXd p_sharp_bck_bck = p_sharp_fwd_fwd; + + // Integrated momenta along trajectory Eigen::VectorXd rho = this->z_.p; - - Eigen::VectorXd p_plus_plus = this->z_.p; - Eigen::VectorXd p_plus_minus = this->z_.p; - Eigen::VectorXd p_minus_plus = this->z_.p; - Eigen::VectorXd p_minus_minus = this->z_.p; + // Log sum of state weights (offset by H0) along trajectory double log_sum_weight = 0; // log(exp(H0 - H0)) double H0 = this->hamiltonian_.H(this->z_); int n_leapfrog = 0; double sum_metro_prob = 0; - // Build a trajectory until the NUTS criterion is no longer satisfied + // Build a trajectory until the no-u-turn + // criterion is no longer satisfied this->depth_ = 0; this->divergent_ = false; while (this->depth_ < this->max_depth_) { // Build a new subtree in a random direction - Eigen::VectorXd rho_plus = Eigen::VectorXd::Zero(rho.size()); - Eigen::VectorXd rho_minus = Eigen::VectorXd::Zero(rho.size()); - + Eigen::VectorXd rho_fwd = Eigen::VectorXd::Zero(rho.size()); + Eigen::VectorXd rho_bck = Eigen::VectorXd::Zero(rho.size()); + bool valid_subtree = false; double log_sum_weight_subtree - = -std::numeric_limits::infinity(); + = -std::numeric_limits::infinity(); if (this->rand_uniform_() > 0.5) { - this->z_.ps_point::operator=(z_plus); - rho_minus = rho; + // Extend the current trajectory forward + this->z_.ps_point::operator=(z_fwd); + rho_bck = rho; valid_subtree = build_tree(this->depth_, z_propose, - p_sharp_plus_minus, p_sharp_plus_plus, - rho_plus, p_plus_minus, p_plus_plus, + p_sharp_fwd_bck, p_sharp_fwd_fwd, + rho_fwd, p_fwd_bck, p_fwd_fwd, H0, 1, n_leapfrog, log_sum_weight_subtree, sum_metro_prob, logger); - z_plus.ps_point::operator=(this->z_); - + z_fwd.ps_point::operator=(this->z_); } else { - this->z_.ps_point::operator=(z_minus); - rho_plus = rho; + // Extend the current trajectory backwards + this->z_.ps_point::operator=(z_bck); + rho_fwd = rho; valid_subtree = build_tree(this->depth_, z_propose, - p_sharp_minus_plus, p_sharp_minus_minus, - rho_minus, p_minus_plus, p_minus_minus, + p_sharp_bck_fwd, p_sharp_bck_bck, + rho_minus, p_bck_fwd, p_bck_bck, H0, -1, n_leapfrog, log_sum_weight_subtree, sum_metro_prob, logger); - z_minus.ps_point::operator=(this->z_); + z_bck.ps_point::operator=(this->z_); } if (!valid_subtree) break; - // Sample from an accepted subtree + // Sample from accepted subtree ++(this->depth_); if (log_sum_weight_subtree > log_sum_weight) { @@ -161,27 +171,27 @@ namespace stan { // Break when no-u-turn criterion is no longer satisfied rho = rho_minus + rho_plus; - - // Boundary check - bool persist_criterion = - compute_criterion(p_sharp_minus_minus, - p_sharp_plus_plus, + + // Demand satisfaction around merged subtrees + bool persist_criterion = + compute_criterion(p_sharp_bkc_bck, + p_sharp_fwd_fwd, rho); - - // Extra internal checks - Eigen::VectorXd rho_extended = rho_minus + p_plus_minus; - - persist_criterion &= - compute_criterion(p_sharp_minus_minus, - p_sharp_plus_minus, + + // Demand satisfaction between subtrees + Eigen::VectorXd rho_extended = rho_bck + p_fwd_bck; + + persist_criterion &= + compute_criterion(p_sharp_bck_bck, + p_sharp_fwd_bck, rho_extended); - - rho_extended = rho_plus + p_minus_plus; - persist_criterion &= - compute_criterion(p_sharp_minus_plus, - p_sharp_plus_plus, + + rho_extended = rho_fwd + p_bck_fwd; + persist_criterion &= + compute_criterion(p_sharp_bck_fwd, + p_sharp_fwd_fwd, rho_extended); - + if (!persist_criterion) break; } @@ -228,11 +238,11 @@ namespace stan { * * @param depth Depth of the desired subtree * @param z_propose State proposed from subtree - * @param p_sharp_left p_sharp at the left boundary of returned tree - * @param p_sharp_right p_sharp at the right boundary of returned tree + * @param p_sharp_beg Sharp momentum at beginning of new tree + * @param p_sharp_end Sharp momentum at end of new tree * @param rho Summed momentum across trajectory - * @param p_left p at the left boundary of returned tree - * @param p_right p at the right boundary of returned tree + * @param p_beg Momentum at beginning of returned tree + * @param p_end Momentum at end of returned tree * @param H0 Hamiltonian of initial state * @param sign Direction in time to built subtree * @param n_leapfrog Summed number of leapfrog evaluations @@ -241,11 +251,11 @@ namespace stan { * @param logger Logger for messages */ bool build_tree(int depth, ps_point& z_propose, - Eigen::VectorXd& p_sharp_left, - Eigen::VectorXd& p_sharp_right, + Eigen::VectorXd& p_sharp_beg, + Eigen::VectorXd& p_sharp_end, Eigen::VectorXd& rho, - Eigen::VectorXd& p_left, - Eigen::VectorXd& p_right, + Eigen::VectorXd& p_beg, + Eigen::VectorXd& p_end, double H0, double sign, int& n_leapfrog, double& log_sum_weight, double& sum_metro_prob, callbacks::logger& logger) { @@ -271,82 +281,88 @@ namespace stan { z_propose = this->z_; - p_sharp_left = this->hamiltonian_.dtau_dp(this->z_); - p_sharp_right = p_sharp_left; - + p_sharp_beg = this->hamiltonian_.dtau_dp(this->z_); + p_sharp_end = p_sharp_beg; + rho += this->z_.p; - p_left = this->z_.p; - p_right = p_left; + p_beg = this->z_.p; + p_end = p_beg; return !this->divergent_; } // General recursion - // Build the left subtree - double log_sum_weight_left = -std::numeric_limits::infinity(); - Eigen::VectorXd p_sharp_left_right(this->z_.p.size()); - Eigen::VectorXd rho_left = Eigen::VectorXd::Zero(rho.size()); - Eigen::VectorXd p_left_right(this->z_.p.size()); + // Build the initial subtree + double log_sum_weight_init = -std::numeric_limits::infinity(); + + // Momentum and sharp momentum at end of the initial subtree + Eigen::VectorXd p_init_end(this->z_.p.size()); + Eigen::VectorXd p_sharp_init_end(this->z_.p.size()); + + Eigen::VectorXd rho_init = Eigen::VectorXd::Zero(rho.size()); - bool valid_left + bool valid_init = build_tree(depth - 1, z_propose, - p_sharp_left, p_sharp_left_right, - rho_left, p_left, p_left_right, + p_sharp_beg, p_sharp_init_end, + rho_init, p_beg, p_init_end, H0, sign, n_leapfrog, - log_sum_weight_left, sum_metro_prob, + log_sum_weight_init, sum_metro_prob, logger); - if (!valid_left) return false; + if (!valid_init) return false; + + // Build the final subtree + ps_point z_propose_final(this->z_); + + double log_sum_weight_final = -std::numeric_limits::infinity(); - // Build the right subtree - ps_point z_propose_right(this->z_); + // Momentum and sharp momentum at beginning of the final subtree + Eigen::VectorXd p_final_beg(this->z_.p.size()); + Eigen::VectorXd p_sharp_final_beg(this->z_.p.size()); - double log_sum_weight_right = -std::numeric_limits::infinity(); - Eigen::VectorXd p_sharp_right_left(this->z_.p.size()); - Eigen::VectorXd rho_right = Eigen::VectorXd::Zero(rho.size()); - Eigen::VectorXd p_right_left(this->z_.p.size()); + Eigen::VectorXd rho_final = Eigen::VectorXd::Zero(rho.size()); - bool valid_right - = build_tree(depth - 1, z_propose_right, - p_sharp_right_left, p_sharp_right, - rho_right, p_right_left, p_right, + bool valid_final + = build_tree(depth - 1, z_propose_final, + p_sharp_final_beg, p_sharp_end, + rho_final, p_final_beg, p_end, H0, sign, n_leapfrog, - log_sum_weight_right, sum_metro_prob, + log_sum_weight_final, sum_metro_prob, logger); - if (!valid_right) return false; + if (!valid_final) return false; // Multinomial sample from right subtree double log_sum_weight_subtree - = math::log_sum_exp(log_sum_weight_left, log_sum_weight_right); + = math::log_sum_exp(log_sum_weight_init, log_sum_weight_final); log_sum_weight = math::log_sum_exp(log_sum_weight, log_sum_weight_subtree); - if (log_sum_weight_right > log_sum_weight_subtree) { - z_propose = z_propose_right; + if (log_sum_weight_final > log_sum_weight_subtree) { + z_propose = z_propose_final; } else { double accept_prob - = std::exp(log_sum_weight_right - log_sum_weight_subtree); + = std::exp(log_sum_weight_final - log_sum_weight_subtree); if (this->rand_uniform_() < accept_prob) - z_propose = z_propose_right; + z_propose = z_propose_final; } - - Eigen::VectorXd rho_subtree = rho_left + rho_right; + + Eigen::VectorXd rho_subtree = rho_init + rho_final; rho += rho_subtree; - - // Boundary check - bool persist_criterion = - compute_criterion(p_sharp_left, p_sharp_right, rho_subtree); - - // Extra internal checks - rho_subtree = rho_left + p_right_left; - persist_criterion &= - compute_criterion(p_sharp_left, p_sharp_right_left, rho_subtree); - - rho_subtree = rho_right + p_left_right; - persist_criterion &= - compute_criterion(p_sharp_left_right, p_sharp_right, rho_subtree); - + + // Demand satisfaction around merged subtrees + bool persist_criterion = + compute_criterion(p_sharp_beg, p_sharp_end, rho_subtree); + + // Demand satisfaction between subtrees + rho_subtree = rho_init + p_final_beg; + persist_criterion &= + compute_criterion(p_sharp_beg, p_sharp_final_beg, rho_subtree); + + rho_subtree = rho_final + p_init_end; + persist_criterion &= + compute_criterion(p_sharp_init_end, p_sharp_end, rho_subtree); + return persist_criterion; } From 17ad3e200775bc44f4c1fdf74189fb7ffa677d53 Mon Sep 17 00:00:00 2001 From: betanalpha Date: Mon, 12 Aug 2019 21:49:31 -0400 Subject: [PATCH 03/12] Fix typos --- src/stan/mcmc/hmc/nuts/base_nuts.hpp | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/stan/mcmc/hmc/nuts/base_nuts.hpp b/src/stan/mcmc/hmc/nuts/base_nuts.hpp index 7c9e1768f73..f0442c5ba95 100644 --- a/src/stan/mcmc/hmc/nuts/base_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/base_nuts.hpp @@ -82,17 +82,17 @@ namespace stan { this->hamiltonian_.init(this->z_, logger); ps_point z_fwd(this->z_); // State at forward end of trajectory - ps_point z_bck(z_plus); // State at backward end of trajectory + ps_point z_bck(z_fwd); // State at backward end of trajectory - ps_point z_sample(z_plus); - ps_point z_propose(z_plus); + ps_point z_sample(z_fwd); + ps_point z_propose(z_fwd); // Momentum and sharp momentum at forward end of forward subtree Eigen::VectorXd p_fwd_fwd = this->z_.p; Eigen::VectorXd p_sharp_fwd_fwd = this->hamiltonian_.dtau_dp(this->z_); // Momentum and sharp momentum at backward end of forward subtree - Eigen::VectorXd p_fwd_bkc = this->z_.p; + Eigen::VectorXd p_fwd_bck = this->z_.p; Eigen::VectorXd p_sharp_fwd_bck = p_sharp_fwd_fwd; // Momentum and sharp momentum at forward end of backward subtree @@ -145,7 +145,7 @@ namespace stan { valid_subtree = build_tree(this->depth_, z_propose, p_sharp_bck_fwd, p_sharp_bck_bck, - rho_minus, p_bck_fwd, p_bck_bck, + rho_bck, p_bck_fwd, p_bck_bck, H0, -1, n_leapfrog, log_sum_weight_subtree, sum_metro_prob, logger); @@ -170,11 +170,11 @@ namespace stan { = math::log_sum_exp(log_sum_weight, log_sum_weight_subtree); // Break when no-u-turn criterion is no longer satisfied - rho = rho_minus + rho_plus; + rho = rho_bck + rho_fwd; // Demand satisfaction around merged subtrees bool persist_criterion = - compute_criterion(p_sharp_bkc_bck, + compute_criterion(p_sharp_bck_bck, p_sharp_fwd_fwd, rho); From 555efc0e7966d371ba507bdf013823086317a218 Mon Sep 17 00:00:00 2001 From: betanalpha Date: Mon, 12 Aug 2019 22:39:25 -0400 Subject: [PATCH 04/12] Update tests --- .../unit/mcmc/hmc/nuts/base_nuts_test.cpp | 50 ++++-- .../unit/mcmc/hmc/nuts/softabs_nuts_test.cpp | 161 +++++++++++------ .../unit/mcmc/hmc/nuts/unit_e_nuts_test.cpp | 162 ++++++++++++------ 3 files changed, 250 insertions(+), 123 deletions(-) diff --git a/src/test/unit/mcmc/hmc/nuts/base_nuts_test.cpp b/src/test/unit/mcmc/hmc/nuts/base_nuts_test.cpp index de309c11460..870ec31f17f 100644 --- a/src/test/unit/mcmc/hmc/nuts/base_nuts_test.cpp +++ b/src/test/unit/mcmc/hmc/nuts/base_nuts_test.cpp @@ -149,9 +149,12 @@ TEST(McmcNutsBaseNuts, build_tree_test) { stan::mcmc::ps_point z_propose(model_size); - Eigen::VectorXd p_sharp_left = Eigen::VectorXd::Zero(model_size); - Eigen::VectorXd p_sharp_right = Eigen::VectorXd::Zero(model_size); + Eigen::VectorXd p_begin = Eigen::VectorXd::Zero(model_size); + Eigen::VectorXd p_sharp_begin = Eigen::VectorXd::Zero(model_size); + Eigen::VectorXd p_end = Eigen::VectorXd::Zero(model_size); + Eigen::VectorXd p_sharp_end = Eigen::VectorXd::Zero(model_size); Eigen::VectorXd rho = z_init.p; + double log_sum_weight = -std::numeric_limits::infinity(); double H0 = -0.1; @@ -170,15 +173,18 @@ TEST(McmcNutsBaseNuts, build_tree_test) { stan::callbacks::stream_logger logger(debug, info, warn, error, fatal); bool valid_subtree = sampler.build_tree(3, z_propose, - p_sharp_left, p_sharp_right, rho, + p_sharp_begin, p_sharp_end, + rho, p_begin, p_end, H0, 1, n_leapfrog, log_sum_weight, sum_metro_prob, logger); EXPECT_TRUE(valid_subtree); EXPECT_EQ(init_momentum * (n_leapfrog + 1), rho(0)); - EXPECT_EQ(1, p_sharp_left(0)); - EXPECT_EQ(1, p_sharp_right(0)); + EXPECT_EQ(1.5, p_begin(0)); + EXPECT_EQ(1, p_sharp_begin(0)); + EXPECT_EQ(1.5, p_end(0)); + EXPECT_EQ(1, p_sharp_end(0)); EXPECT_EQ(8 * init_momentum, sampler.z().q(0)); EXPECT_EQ(init_momentum, sampler.z().p(0)); @@ -207,9 +213,12 @@ TEST(McmcNutsBaseNuts, rho_aggregation_test) { stan::mcmc::ps_point z_propose(model_size); - Eigen::VectorXd p_sharp_left = Eigen::VectorXd::Zero(model_size); - Eigen::VectorXd p_sharp_right = Eigen::VectorXd::Zero(model_size); + Eigen::VectorXd p_begin = Eigen::VectorXd::Zero(model_size); + Eigen::VectorXd p_sharp_begin = Eigen::VectorXd::Zero(model_size); + Eigen::VectorXd p_end = Eigen::VectorXd::Zero(model_size); + Eigen::VectorXd p_sharp_end = Eigen::VectorXd::Zero(model_size); Eigen::VectorXd rho = z_init.p; + double log_sum_weight = -std::numeric_limits::infinity(); double H0 = -0.1; @@ -228,18 +237,19 @@ TEST(McmcNutsBaseNuts, rho_aggregation_test) { stan::callbacks::stream_logger logger(debug, info, warn, error, fatal); sampler.build_tree(3, z_propose, - p_sharp_left, p_sharp_right, rho, + p_sharp_begin, p_sharp_end, + rho, p_begin, p_end, H0, 1, n_leapfrog, log_sum_weight, sum_metro_prob, logger); - EXPECT_EQ(7, sampler.rho_values.size()); + EXPECT_EQ(7 * 3, sampler.rho_values.size()); EXPECT_EQ(2 * init_momentum, sampler.rho_values.at(0)); EXPECT_EQ(2 * init_momentum, sampler.rho_values.at(1)); - EXPECT_EQ(4 * init_momentum, sampler.rho_values.at(2)); + EXPECT_EQ(2 * init_momentum, sampler.rho_values.at(2)); EXPECT_EQ(2 * init_momentum, sampler.rho_values.at(3)); EXPECT_EQ(2 * init_momentum, sampler.rho_values.at(4)); - EXPECT_EQ(4 * init_momentum, sampler.rho_values.at(5)); - EXPECT_EQ(8 * init_momentum, sampler.rho_values.at(6)); + EXPECT_EQ(2 * init_momentum, sampler.rho_values.at(5)); + EXPECT_EQ(4 * init_momentum, sampler.rho_values.at(6)); } TEST(McmcNutsBaseNuts, divergence_test) { @@ -255,9 +265,12 @@ TEST(McmcNutsBaseNuts, divergence_test) { stan::mcmc::ps_point z_propose(model_size); - Eigen::VectorXd p_sharp_left = Eigen::VectorXd::Zero(model_size); - Eigen::VectorXd p_sharp_right = Eigen::VectorXd::Zero(model_size); + Eigen::VectorXd p_begin = Eigen::VectorXd::Zero(model_size); + Eigen::VectorXd p_sharp_begin = Eigen::VectorXd::Zero(model_size); + Eigen::VectorXd p_end = Eigen::VectorXd::Zero(model_size); + Eigen::VectorXd p_sharp_end = Eigen::VectorXd::Zero(model_size); Eigen::VectorXd rho = z_init.p; + double log_sum_weight = -std::numeric_limits::infinity(); double H0 = -0.1; @@ -279,7 +292,8 @@ TEST(McmcNutsBaseNuts, divergence_test) { sampler.z().V = -750; valid_subtree = sampler.build_tree(0, z_propose, - p_sharp_left, p_sharp_right, rho, + p_sharp_begin, p_sharp_end, + rho, p_begin, p_end, H0, 1, n_leapfrog, log_sum_weight, sum_metro_prob, logger); @@ -288,7 +302,8 @@ TEST(McmcNutsBaseNuts, divergence_test) { sampler.z().V = -250; valid_subtree = sampler.build_tree(0, z_propose, - p_sharp_left, p_sharp_right, rho, + p_sharp_begin, p_sharp_end, + rho, p_begin, p_end, H0, 1, n_leapfrog, log_sum_weight, sum_metro_prob, logger); @@ -298,7 +313,8 @@ TEST(McmcNutsBaseNuts, divergence_test) { sampler.z().V = 750; valid_subtree = sampler.build_tree(0, z_propose, - p_sharp_left, p_sharp_right, rho, + p_sharp_begin, p_sharp_end, + rho, p_begin, p_end, H0, 1, n_leapfrog, log_sum_weight, sum_metro_prob, logger); diff --git a/src/test/unit/mcmc/hmc/nuts/softabs_nuts_test.cpp b/src/test/unit/mcmc/hmc/nuts/softabs_nuts_test.cpp index 94c09614c70..1e07a2bdd40 100644 --- a/src/test/unit/mcmc/hmc/nuts/softabs_nuts_test.cpp +++ b/src/test/unit/mcmc/hmc/nuts/softabs_nuts_test.cpp @@ -38,9 +38,12 @@ TEST(McmcSoftAbsNuts, build_tree_test) { stan::mcmc::ps_point z_propose = z_init; - Eigen::VectorXd p_sharp_left = Eigen::VectorXd::Zero(z_init.p.size()); - Eigen::VectorXd p_sharp_right = Eigen::VectorXd::Zero(z_init.p.size()); + Eigen::VectorXd p_begin = Eigen::VectorXd::Zero(3); + Eigen::VectorXd p_sharp_begin = Eigen::VectorXd::Zero(3); + Eigen::VectorXd p_end = Eigen::VectorXd::Zero(3); + Eigen::VectorXd p_sharp_end = Eigen::VectorXd::Zero(3); Eigen::VectorXd rho = z_init.p; + double log_sum_weight = -std::numeric_limits::infinity(); double H0 = -0.1; @@ -48,7 +51,8 @@ TEST(McmcSoftAbsNuts, build_tree_test) { double sum_metro_prob = 0; bool valid_subtree = sampler.build_tree(3, z_propose, - p_sharp_left, p_sharp_right, rho, + p_sharp_begin, p_sharp_end, + rho, p_begin, p_end, H0, 1, n_leapfrog, log_sum_weight, sum_metro_prob, logger); @@ -60,6 +64,22 @@ TEST(McmcSoftAbsNuts, build_tree_test) { EXPECT_FLOAT_EQ(11.679803, rho(1)); EXPECT_FLOAT_EQ(-11.679803, rho(2)); + EXPECT_FLOAT_EQ(-1.0960016, p_begin(0)); + EXPECT_FLOAT_EQ(1.0960016, p_begin(1)); + EXPECT_FLOAT_EQ(-1.0960016, p_begin(2)); + + EXPECT_FLOAT_EQ(-0.83470845, p_sharp_begin(0)); + EXPECT_FLOAT_EQ(0.83470845, p_sharp_begin(1)); + EXPECT_FLOAT_EQ(-0.83470845, p_sharp_begin(2)); + + EXPECT_FLOAT_EQ(-1.5019561, p_end(0)); + EXPECT_FLOAT_EQ(1.5019561, p_end(1)); + EXPECT_FLOAT_EQ(-1.5019561, p_end(2)); + + EXPECT_FLOAT_EQ(-1.143881, p_sharp_end(0)); + EXPECT_FLOAT_EQ(1.143881, p_sharp_end(1)); + EXPECT_FLOAT_EQ(-1.143881, p_sharp_end(2)); + EXPECT_FLOAT_EQ(0.20423166, sampler.z().q(0)); EXPECT_FLOAT_EQ(-0.20423166, sampler.z().q(1)); EXPECT_FLOAT_EQ(0.20423166, sampler.z().q(2)); @@ -110,38 +130,46 @@ TEST(McmcSoftAbsNuts, tree_boundary_test) { metric.init(z_test, logger); softabs_integrator.evolve(z_test, metric, epsilon, logger); + Eigen::VectorXd p_forward_1 = z_test.p; Eigen::VectorXd p_sharp_forward_1 = metric.dtau_dp(z_test); softabs_integrator.evolve(z_test, metric, epsilon, logger); + Eigen::VectorXd p_forward_2 = z_test.p; Eigen::VectorXd p_sharp_forward_2 = metric.dtau_dp(z_test); softabs_integrator.evolve(z_test, metric, epsilon, logger); softabs_integrator.evolve(z_test, metric, epsilon, logger); + Eigen::VectorXd p_forward_3 = z_test.p; Eigen::VectorXd p_sharp_forward_3 = metric.dtau_dp(z_test); softabs_integrator.evolve(z_test, metric, epsilon, logger); softabs_integrator.evolve(z_test, metric, epsilon, logger); softabs_integrator.evolve(z_test, metric, epsilon, logger); softabs_integrator.evolve(z_test, metric, epsilon, logger); + Eigen::VectorXd p_forward_4 = z_test.p; Eigen::VectorXd p_sharp_forward_4 = metric.dtau_dp(z_test); z_test = z_init; metric.init(z_test, logger); softabs_integrator.evolve(z_test, metric, -epsilon, logger); + Eigen::VectorXd p_backward_1 = z_test.p; Eigen::VectorXd p_sharp_backward_1 = metric.dtau_dp(z_test); softabs_integrator.evolve(z_test, metric, -epsilon, logger); + Eigen::VectorXd p_backward_2 = z_test.p; Eigen::VectorXd p_sharp_backward_2 = metric.dtau_dp(z_test); softabs_integrator.evolve(z_test, metric, -epsilon, logger); softabs_integrator.evolve(z_test, metric, -epsilon, logger); + Eigen::VectorXd p_backward_3 = z_test.p; Eigen::VectorXd p_sharp_backward_3 = metric.dtau_dp(z_test); softabs_integrator.evolve(z_test, metric, -epsilon, logger); softabs_integrator.evolve(z_test, metric, -epsilon, logger); softabs_integrator.evolve(z_test, metric, -epsilon, logger); softabs_integrator.evolve(z_test, metric, -epsilon, logger); + Eigen::VectorXd p_backward_4 = z_test.p; Eigen::VectorXd p_sharp_backward_4 = metric.dtau_dp(z_test); // Check expected tree boundaries to those dynamically geneated by NUTS @@ -153,9 +181,12 @@ TEST(McmcSoftAbsNuts, tree_boundary_test) { stan::mcmc::ps_point z_propose = z_init; - Eigen::VectorXd p_sharp_left = Eigen::VectorXd::Zero(z_init.p.size()); - Eigen::VectorXd p_sharp_right = Eigen::VectorXd::Zero(z_init.p.size()); + Eigen::VectorXd p_begin = Eigen::VectorXd::Zero(3); + Eigen::VectorXd p_sharp_begin = Eigen::VectorXd::Zero(3); + Eigen::VectorXd p_end = Eigen::VectorXd::Zero(3); + Eigen::VectorXd p_sharp_end = Eigen::VectorXd::Zero(3); Eigen::VectorXd rho = z_init.p; + double log_sum_weight = -std::numeric_limits::infinity(); double H0 = -0.1; @@ -166,121 +197,145 @@ TEST(McmcSoftAbsNuts, tree_boundary_test) { sampler.z() = z_init; sampler.init_hamiltonian(logger); sampler.build_tree(0, z_propose, - p_sharp_left, p_sharp_right, rho, + p_sharp_begin, p_sharp_end, + rho, p_begin, p_end, H0, 1, n_leapfrog, log_sum_weight, sum_metro_prob, logger); - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_forward_1(n), p_sharp_left(n)); - - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_forward_1(n), p_sharp_right(n)); + for (int n = 0; n < rho.size(); ++n) { + EXPECT_FLOAT_EQ(p_forward_1(n), p_begin(n)); + EXPECT_FLOAT_EQ(p_sharp_forward_1(n), p_sharp_begin(n)); + + EXPECT_FLOAT_EQ(p_forward_1(n), p_end(n)); + EXPECT_FLOAT_EQ(p_sharp_forward_1(n), p_sharp_end(n)); + } // Depth 1 forward sampler.z() = z_init; sampler.init_hamiltonian(logger); sampler.build_tree(1, z_propose, - p_sharp_left, p_sharp_right, rho, + p_sharp_begin, p_sharp_end, + rho, p_begin, p_end, H0, 1, n_leapfrog, log_sum_weight, sum_metro_prob, logger); - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_forward_1(n), p_sharp_left(n)); - - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_forward_2(n), p_sharp_right(n)); + for (int n = 0; n < rho.size(); ++n) { + EXPECT_FLOAT_EQ(p_forward_1(n), p_begin(n)); + EXPECT_FLOAT_EQ(p_sharp_forward_1(n), p_sharp_begin(n)); + + EXPECT_FLOAT_EQ(p_forward_2(n), p_end(n)); + EXPECT_FLOAT_EQ(p_sharp_forward_2(n), p_sharp_end(n)); + } // Depth 2 forward sampler.z() = z_init; sampler.init_hamiltonian(logger); sampler.build_tree(2, z_propose, - p_sharp_left, p_sharp_right, rho, + p_sharp_begin, p_sharp_end, + rho, p_begin, p_end, H0, 1, n_leapfrog, log_sum_weight, sum_metro_prob, logger); - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_forward_1(n), p_sharp_left(n)); - - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_forward_3(n), p_sharp_right(n)); - + for (int n = 0; n < rho.size(); ++n) { + EXPECT_FLOAT_EQ(p_forward_1(n), p_begin(n)); + EXPECT_FLOAT_EQ(p_sharp_forward_1(n), p_sharp_begin(n)); + + EXPECT_FLOAT_EQ(p_forward_3(n), p_end(n)); + EXPECT_FLOAT_EQ(p_sharp_forward_3(n), p_sharp_end(n)); + } + // Depth 3 forward sampler.z() = z_init; sampler.init_hamiltonian(logger); sampler.build_tree(3, z_propose, - p_sharp_left, p_sharp_right, rho, + p_sharp_begin, p_sharp_end, + rho, p_begin, p_end, H0, 1, n_leapfrog, log_sum_weight, sum_metro_prob, logger); - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_forward_1(n), p_sharp_left(n)); + for (int n = 0; n < rho.size(); ++n) { + EXPECT_FLOAT_EQ(p_forward_1(n), p_begin(n)); + EXPECT_FLOAT_EQ(p_sharp_forward_1(n), p_sharp_begin(n)); - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_forward_4(n), p_sharp_right(n)); + EXPECT_FLOAT_EQ(p_forward_4(n), p_end(n)); + EXPECT_FLOAT_EQ(p_sharp_forward_4(n), p_sharp_end(n)); + } // Depth 0 backward sampler.z() = z_init; sampler.init_hamiltonian(logger); sampler.build_tree(0, z_propose, - p_sharp_left, p_sharp_right, rho, + p_sharp_begin, p_sharp_end, + rho, p_begin, p_end, H0, -1, n_leapfrog, log_sum_weight, sum_metro_prob, logger); - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_backward_1(n), p_sharp_left(n)); - - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_backward_1(n), p_sharp_right(n)); + for (int n = 0; n < rho.size(); ++n) { + EXPECT_FLOAT_EQ(p_backward_1(n), p_begin(n)); + EXPECT_FLOAT_EQ(p_sharp_backward_1(n), p_sharp_begin(n)); + + EXPECT_FLOAT_EQ(p_backward_1(n), p_end(n)); + EXPECT_FLOAT_EQ(p_sharp_backward_1(n), p_sharp_end(n)); + } // Depth 1 backward sampler.z() = z_init; sampler.init_hamiltonian(logger); sampler.build_tree(1, z_propose, - p_sharp_left, p_sharp_right, rho, + p_sharp_begin, p_sharp_end, + rho, p_begin, p_end, H0, -1, n_leapfrog, log_sum_weight, sum_metro_prob, logger); - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_backward_1(n), p_sharp_left(n)); - - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_backward_2(n), p_sharp_right(n)); + for (int n = 0; n < rho.size(); ++n) { + EXPECT_FLOAT_EQ(p_backward_1(n), p_begin(n)); + EXPECT_FLOAT_EQ(p_sharp_backward_1(n), p_sharp_begin(n)); + + EXPECT_FLOAT_EQ(p_backward_2(n), p_end(n)); + EXPECT_FLOAT_EQ(p_sharp_backward_2(n), p_sharp_end(n)); + } // Depth 2 backward sampler.z() = z_init; sampler.init_hamiltonian(logger); sampler.build_tree(2, z_propose, - p_sharp_left, p_sharp_right, rho, + p_sharp_begin, p_sharp_end, + rho, p_begin, p_end, H0, -1, n_leapfrog, log_sum_weight, sum_metro_prob, logger); - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_backward_1(n), p_sharp_left(n)); - - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_backward_3(n), p_sharp_right(n)); + for (int n = 0; n < rho.size(); ++n) { + EXPECT_FLOAT_EQ(p_backward_1(n), p_begin(n)); + EXPECT_FLOAT_EQ(p_sharp_backward_1(n), p_sharp_begin(n)); + + EXPECT_FLOAT_EQ(p_backward_3(n), p_end(n)); + EXPECT_FLOAT_EQ(p_sharp_backward_3(n), p_sharp_end(n)); + } // Depth 3 backward sampler.z() = z_init; sampler.init_hamiltonian(logger); sampler.build_tree(3, z_propose, - p_sharp_left, p_sharp_right, rho, + p_sharp_begin, p_sharp_end, + rho, p_begin, p_end, H0, -1, n_leapfrog, log_sum_weight, sum_metro_prob, logger); - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_backward_1(n), p_sharp_left(n)); - - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_backward_4(n), p_sharp_right(n)); + for (int n = 0; n < rho.size(); ++n) { + EXPECT_FLOAT_EQ(p_backward_1(n), p_begin(n)); + EXPECT_FLOAT_EQ(p_sharp_backward_1(n), p_sharp_begin(n)); + + EXPECT_FLOAT_EQ(p_backward_4(n), p_end(n)); + EXPECT_FLOAT_EQ(p_sharp_backward_4(n), p_sharp_end(n)); + } } TEST(McmcSoftAbsNuts, transition_test) { diff --git a/src/test/unit/mcmc/hmc/nuts/unit_e_nuts_test.cpp b/src/test/unit/mcmc/hmc/nuts/unit_e_nuts_test.cpp index e34a6d69bca..959f26d92ca 100644 --- a/src/test/unit/mcmc/hmc/nuts/unit_e_nuts_test.cpp +++ b/src/test/unit/mcmc/hmc/nuts/unit_e_nuts_test.cpp @@ -38,9 +38,12 @@ TEST(McmcUnitENuts, build_tree_test) { stan::mcmc::ps_point z_propose = z_init; - Eigen::VectorXd p_sharp_left = Eigen::VectorXd::Zero(z_init.p.size()); - Eigen::VectorXd p_sharp_right = Eigen::VectorXd::Zero(z_init.p.size()); + Eigen::VectorXd p_begin = Eigen::VectorXd::Zero(z_init.p.size()); + Eigen::VectorXd p_sharp_begin = Eigen::VectorXd::Zero(z_init.p.size()); + Eigen::VectorXd p_end = Eigen::VectorXd::Zero(z_init.p.size()); + Eigen::VectorXd p_sharp_end = Eigen::VectorXd::Zero(z_init.p.size()); Eigen::VectorXd rho = z_init.p; + double log_sum_weight = -std::numeric_limits::infinity(); double H0 = -0.1; @@ -48,7 +51,8 @@ TEST(McmcUnitENuts, build_tree_test) { double sum_metro_prob = 0; bool valid_subtree = sampler.build_tree(3, z_propose, - p_sharp_left, p_sharp_right, rho, + p_sharp_begin, p_sharp_end, + rho, p_begin, p_end, H0, 1, n_leapfrog, log_sum_weight, sum_metro_prob, logger); @@ -60,6 +64,22 @@ TEST(McmcUnitENuts, build_tree_test) { EXPECT_FLOAT_EQ(-11.401228, rho(0)); EXPECT_FLOAT_EQ(11.401228, rho(1)); EXPECT_FLOAT_EQ(-11.401228, rho(2)); + + EXPECT_FLOAT_EQ(-1.09475, p_begin(0)); + EXPECT_FLOAT_EQ(1.09475, p_begin(1)); + EXPECT_FLOAT_EQ(-1.09475, p_begin(2)); + + EXPECT_FLOAT_EQ(-1.09475, p_sharp_begin(0)); + EXPECT_FLOAT_EQ(1.09475, p_sharp_begin(1)); + EXPECT_FLOAT_EQ(-1.09475, p_sharp_begin(2)); + + EXPECT_FLOAT_EQ(-1.4131583, p_end(0)); + EXPECT_FLOAT_EQ(1.4131583, p_end(1)); + EXPECT_FLOAT_EQ(-1.4131583, p_end(2)); + + EXPECT_FLOAT_EQ(-1.4131583, p_sharp_end(0)); + EXPECT_FLOAT_EQ(1.4131583, p_sharp_end(1)); + EXPECT_FLOAT_EQ(-1.4131583, p_sharp_end(2)); EXPECT_FLOAT_EQ(-0.022019938, sampler.z().q(0)); EXPECT_FLOAT_EQ(0.022019938, sampler.z().q(1)); @@ -111,38 +131,46 @@ TEST(McmcUnitENuts, tree_boundary_test) { metric.init(z_test, logger); unit_e_integrator.evolve(z_test, metric, epsilon, logger); + Eigen::VectorXd p_forward_1 = z_test.p; Eigen::VectorXd p_sharp_forward_1 = metric.dtau_dp(z_test); unit_e_integrator.evolve(z_test, metric, epsilon, logger); + Eigen::VectorXd p_forward_2 = z_test.p; Eigen::VectorXd p_sharp_forward_2 = metric.dtau_dp(z_test); unit_e_integrator.evolve(z_test, metric, epsilon, logger); unit_e_integrator.evolve(z_test, metric, epsilon, logger); + Eigen::VectorXd p_forward_3 = z_test.p; Eigen::VectorXd p_sharp_forward_3 = metric.dtau_dp(z_test); unit_e_integrator.evolve(z_test, metric, epsilon, logger); unit_e_integrator.evolve(z_test, metric, epsilon, logger); unit_e_integrator.evolve(z_test, metric, epsilon, logger); unit_e_integrator.evolve(z_test, metric, epsilon, logger); + Eigen::VectorXd p_forward_4 = z_test.p; Eigen::VectorXd p_sharp_forward_4 = metric.dtau_dp(z_test); z_test = z_init; metric.init(z_test, logger); unit_e_integrator.evolve(z_test, metric, -epsilon, logger); + Eigen::VectorXd p_backward_1 = z_test.p; Eigen::VectorXd p_sharp_backward_1 = metric.dtau_dp(z_test); unit_e_integrator.evolve(z_test, metric, -epsilon, logger); + Eigen::VectorXd p_backward_2 = z_test.p; Eigen::VectorXd p_sharp_backward_2 = metric.dtau_dp(z_test); unit_e_integrator.evolve(z_test, metric, -epsilon, logger); unit_e_integrator.evolve(z_test, metric, -epsilon, logger); + Eigen::VectorXd p_backward_3 = z_test.p; Eigen::VectorXd p_sharp_backward_3 = metric.dtau_dp(z_test); unit_e_integrator.evolve(z_test, metric, -epsilon, logger); unit_e_integrator.evolve(z_test, metric, -epsilon, logger); unit_e_integrator.evolve(z_test, metric, -epsilon, logger); unit_e_integrator.evolve(z_test, metric, -epsilon, logger); + Eigen::VectorXd p_backward_4 = z_test.p; Eigen::VectorXd p_sharp_backward_4 = metric.dtau_dp(z_test); // Check expected tree boundaries to those dynamically geneated by NUTS @@ -154,9 +182,12 @@ TEST(McmcUnitENuts, tree_boundary_test) { stan::mcmc::ps_point z_propose = z_init; - Eigen::VectorXd p_sharp_left = Eigen::VectorXd::Zero(z_init.p.size()); - Eigen::VectorXd p_sharp_right = Eigen::VectorXd::Zero(z_init.p.size()); + Eigen::VectorXd p_begin = Eigen::VectorXd::Zero(z_init.p.size()); + Eigen::VectorXd p_sharp_begin = Eigen::VectorXd::Zero(z_init.p.size()); + Eigen::VectorXd p_end = Eigen::VectorXd::Zero(z_init.p.size()); + Eigen::VectorXd p_sharp_end = Eigen::VectorXd::Zero(z_init.p.size()); Eigen::VectorXd rho = z_init.p; + double log_sum_weight = -std::numeric_limits::infinity(); double H0 = -0.1; @@ -167,121 +198,146 @@ TEST(McmcUnitENuts, tree_boundary_test) { sampler.z() = z_init; sampler.init_hamiltonian(logger); sampler.build_tree(0, z_propose, - p_sharp_left, p_sharp_right, rho, + p_sharp_begin, p_sharp_end, + rho, p_begin, p_end, H0, 1, n_leapfrog, log_sum_weight, sum_metro_prob, logger); - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_forward_1(n), p_sharp_left(n)); - - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_forward_1(n), p_sharp_right(n)); - + for (int n = 0; n < rho.size(); ++n) { + EXPECT_FLOAT_EQ(p_forward_1(n), p_begin(n)); + EXPECT_FLOAT_EQ(p_sharp_forward_1(n), p_sharp_begin(n)); + + EXPECT_FLOAT_EQ(p_forward_1(n), p_end(n)); + EXPECT_FLOAT_EQ(p_sharp_forward_1(n), p_sharp_end(n)); + } + // Depth 1 forward sampler.z() = z_init; sampler.init_hamiltonian(logger); sampler.build_tree(1, z_propose, - p_sharp_left, p_sharp_right, rho, + p_sharp_begin, p_sharp_end, + rho, p_begin, p_end, H0, 1, n_leapfrog, log_sum_weight, sum_metro_prob, logger); - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_forward_1(n), p_sharp_left(n)); - - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_forward_2(n), p_sharp_right(n)); + for (int n = 0; n < rho.size(); ++n) { + EXPECT_FLOAT_EQ(p_forward_1(n), p_begin(n)); + EXPECT_FLOAT_EQ(p_sharp_forward_1(n), p_sharp_begin(n)); + + EXPECT_FLOAT_EQ(p_forward_2(n), p_end(n)); + EXPECT_FLOAT_EQ(p_sharp_forward_2(n), p_sharp_end(n)); + } // Depth 2 forward sampler.z() = z_init; sampler.init_hamiltonian(logger); sampler.build_tree(2, z_propose, - p_sharp_left, p_sharp_right, rho, + p_sharp_begin, p_sharp_end, + rho, p_begin, p_end, H0, 1, n_leapfrog, log_sum_weight, sum_metro_prob, logger); - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_forward_1(n), p_sharp_left(n)); - - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_forward_3(n), p_sharp_right(n)); + for (int n = 0; n < rho.size(); ++n) { + EXPECT_FLOAT_EQ(p_forward_1(n), p_begin(n)); + EXPECT_FLOAT_EQ(p_sharp_forward_1(n), p_sharp_begin(n)); + + EXPECT_FLOAT_EQ(p_forward_3(n), p_end(n)); + EXPECT_FLOAT_EQ(p_sharp_forward_3(n), p_sharp_end(n)); + } // Depth 3 forward sampler.z() = z_init; sampler.init_hamiltonian(logger); sampler.build_tree(3, z_propose, - p_sharp_left, p_sharp_right, rho, + p_sharp_begin, p_sharp_end, + rho, p_begin, p_end, H0, 1, n_leapfrog, log_sum_weight, sum_metro_prob, logger); - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_forward_1(n), p_sharp_left(n)); - - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_forward_4(n), p_sharp_right(n)); + for (int n = 0; n < rho.size(); ++n) { + EXPECT_FLOAT_EQ(p_forward_1(n), p_begin(n)); + EXPECT_FLOAT_EQ(p_sharp_forward_1(n), p_sharp_begin(n)); + + EXPECT_FLOAT_EQ(p_forward_4(n), p_end(n)); + EXPECT_FLOAT_EQ(p_sharp_forward_4(n), p_sharp_end(n)); + } // Depth 0 backward sampler.z() = z_init; sampler.init_hamiltonian(logger); sampler.build_tree(0, z_propose, - p_sharp_left, p_sharp_right, rho, + p_sharp_begin, p_sharp_end, + rho, p_begin, p_end, H0, -1, n_leapfrog, log_sum_weight, sum_metro_prob, logger); - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_backward_1(n), p_sharp_left(n)); - - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_backward_1(n), p_sharp_right(n)); + for (int n = 0; n < rho.size(); ++n) { + EXPECT_FLOAT_EQ(p_backward_1(n), p_begin(n)); + EXPECT_FLOAT_EQ(p_sharp_backward_1(n), p_sharp_begin(n)); + + EXPECT_FLOAT_EQ(p_backward_1(n), p_end(n)); + EXPECT_FLOAT_EQ(p_sharp_backward_1(n), p_sharp_end(n)); + } // Depth 1 backward sampler.z() = z_init; sampler.init_hamiltonian(logger); sampler.build_tree(1, z_propose, - p_sharp_left, p_sharp_right, rho, + p_sharp_begin, p_sharp_end, + rho, p_begin, p_end, H0, -1, n_leapfrog, log_sum_weight, sum_metro_prob, logger); - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_backward_1(n), p_sharp_left(n)); - - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_backward_2(n), p_sharp_right(n)); + for (int n = 0; n < rho.size(); ++n) { + EXPECT_FLOAT_EQ(p_backward_1(n), p_begin(n)); + EXPECT_FLOAT_EQ(p_sharp_backward_1(n), p_sharp_begin(n)); + + EXPECT_FLOAT_EQ(p_backward_2(n), p_end(n)); + EXPECT_FLOAT_EQ(p_sharp_backward_2(n), p_sharp_end(n)); + } // Depth 2 backward sampler.z() = z_init; sampler.init_hamiltonian(logger); sampler.build_tree(2, z_propose, - p_sharp_left, p_sharp_right, rho, + p_sharp_begin, p_sharp_end, + rho, p_begin, p_end, H0, -1, n_leapfrog, log_sum_weight, sum_metro_prob, logger); - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_backward_1(n), p_sharp_left(n)); - - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_backward_3(n), p_sharp_right(n)); + for (int n = 0; n < rho.size(); ++n) { + EXPECT_FLOAT_EQ(p_backward_1(n), p_begin(n)); + EXPECT_FLOAT_EQ(p_sharp_backward_1(n), p_sharp_begin(n)); + + EXPECT_FLOAT_EQ(p_backward_3(n), p_end(n)); + EXPECT_FLOAT_EQ(p_sharp_backward_3(n), p_sharp_end(n)); + } // Depth 3 backward sampler.z() = z_init; sampler.init_hamiltonian(logger); sampler.build_tree(3, z_propose, - p_sharp_left, p_sharp_right, rho, + p_sharp_begin, p_sharp_end, + rho, p_begin, p_end, H0, -1, n_leapfrog, log_sum_weight, sum_metro_prob, logger); - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_backward_1(n), p_sharp_left(n)); + for (int n = 0; n < rho.size(); ++n) { + EXPECT_FLOAT_EQ(p_backward_1(n), p_begin(n)); + EXPECT_FLOAT_EQ(p_sharp_backward_1(n), p_sharp_begin(n)); + + EXPECT_FLOAT_EQ(p_backward_4(n), p_end(n)); + EXPECT_FLOAT_EQ(p_sharp_backward_4(n), p_sharp_end(n)); + } - for (int n = 0; n < rho.size(); ++n) - EXPECT_FLOAT_EQ(p_sharp_backward_4(n), p_sharp_right(n)); } TEST(McmcUnitENuts, transition_test) { From cdf65416a6bb337a07954cd5b967469a3aef3739 Mon Sep 17 00:00:00 2001 From: betanalpha Date: Wed, 14 Aug 2019 18:32:23 -0400 Subject: [PATCH 05/12] Update performance test --- src/test/performance/logistic_test.cpp | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/test/performance/logistic_test.cpp b/src/test/performance/logistic_test.cpp index 0efe715a99c..33c8c14d471 100644 --- a/src/test/performance/logistic_test.cpp +++ b/src/test/performance/logistic_test.cpp @@ -108,13 +108,13 @@ TEST_F(performance, values_from_tagged_version) { << "last tagged version, 2.17.0, had " << N_values << " elements"; std::vector first_run = last_draws_per_run[0]; - EXPECT_FLOAT_EQ(-65.781998, first_run[0]) + EXPECT_FLOAT_EQ(-66.222504, first_run[0]) << "lp__: index 0"; EXPECT_FLOAT_EQ(1.0, first_run[1]) << "accept_stat__: index 1"; - EXPECT_FLOAT_EQ(0.76853198, first_run[2]) + EXPECT_FLOAT_EQ(1.00182, first_run[2]) << "stepsize__: index 2"; EXPECT_FLOAT_EQ(2, first_run[3]) @@ -126,13 +126,13 @@ TEST_F(performance, values_from_tagged_version) { EXPECT_FLOAT_EQ(0, first_run[5]) << "divergent__: index 5"; - EXPECT_FLOAT_EQ(66.6695, first_run[6]) + EXPECT_FLOAT_EQ(67.6213, first_run[6]) << "energy__: index 6"; - EXPECT_FLOAT_EQ(1.55186, first_run[7]) + EXPECT_FLOAT_EQ(1.36304, first_run[7]) << "beta.1: index 7"; - EXPECT_FLOAT_EQ(-0.52400702, first_run[8]) + EXPECT_FLOAT_EQ(-0.86114103, first_run[8]) << "beta.2: index 8"; matches_tagged_version = !HasNonfatalFailure(); From 104973d37217292a17dccdb2f32f7d97c93f5d15 Mon Sep 17 00:00:00 2001 From: betanalpha Date: Wed, 28 Aug 2019 16:46:29 -0400 Subject: [PATCH 06/12] Update rho aggregation test --- src/test/unit/mcmc/hmc/nuts/base_nuts_test.cpp | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/test/unit/mcmc/hmc/nuts/base_nuts_test.cpp b/src/test/unit/mcmc/hmc/nuts/base_nuts_test.cpp index 870ec31f17f..a46af184590 100644 --- a/src/test/unit/mcmc/hmc/nuts/base_nuts_test.cpp +++ b/src/test/unit/mcmc/hmc/nuts/base_nuts_test.cpp @@ -243,13 +243,13 @@ TEST(McmcNutsBaseNuts, rho_aggregation_test) { sum_metro_prob, logger); EXPECT_EQ(7 * 3, sampler.rho_values.size()); - EXPECT_EQ(2 * init_momentum, sampler.rho_values.at(0)); - EXPECT_EQ(2 * init_momentum, sampler.rho_values.at(1)); - EXPECT_EQ(2 * init_momentum, sampler.rho_values.at(2)); - EXPECT_EQ(2 * init_momentum, sampler.rho_values.at(3)); - EXPECT_EQ(2 * init_momentum, sampler.rho_values.at(4)); - EXPECT_EQ(2 * init_momentum, sampler.rho_values.at(5)); - EXPECT_EQ(4 * init_momentum, sampler.rho_values.at(6)); + + std::vector rho_scales = {2, 2, 2, 2, 2, 2, 4, 3, 3, 2, 2, + 2, 2, 2, 2, 4, 3, 3, 8, 5, 5}; + + for (int n = 0; n < 21; ++n) { + EXPECT_EQ(rho_scales[n] * init_momentum, sampler.rho_values[n]); + } } TEST(McmcNutsBaseNuts, divergence_test) { From 6041a8758eb7c82bc693fbe13de41836cc2881b5 Mon Sep 17 00:00:00 2001 From: betanalpha Date: Thu, 29 Aug 2019 14:14:31 -0400 Subject: [PATCH 07/12] Update base nuts test --- .../unit/mcmc/hmc/nuts/base_nuts_test.cpp | 30 +++++++++++++++---- 1 file changed, 25 insertions(+), 5 deletions(-) diff --git a/src/test/unit/mcmc/hmc/nuts/base_nuts_test.cpp b/src/test/unit/mcmc/hmc/nuts/base_nuts_test.cpp index a46af184590..1955083a0d8 100644 --- a/src/test/unit/mcmc/hmc/nuts/base_nuts_test.cpp +++ b/src/test/unit/mcmc/hmc/nuts/base_nuts_test.cpp @@ -244,12 +244,32 @@ TEST(McmcNutsBaseNuts, rho_aggregation_test) { EXPECT_EQ(7 * 3, sampler.rho_values.size()); - std::vector rho_scales = {2, 2, 2, 2, 2, 2, 4, 3, 3, 2, 2, - 2, 2, 2, 2, 4, 3, 3, 8, 5, 5}; + // Trajectory component spanning rhos + EXPECT_EQ(2 * init_momentum, sampler.rho_values[0]); + EXPECT_EQ(2 * init_momentum, sampler.rho_values[3]); + EXPECT_EQ(4 * init_momentum, sampler.rho_values[6]); + EXPECT_EQ(2 * init_momentum, sampler.rho_values[9]); + EXPECT_EQ(2 * init_momentum, sampler.rho_values[12]); + EXPECT_EQ(4 * init_momentum, sampler.rho_values[15]); + EXPECT_EQ(8 * init_momentum, sampler.rho_values[18]); + + // Cross trajectory component rhos + EXPECT_EQ(2 * init_momentum, sampler.rho_values[1]); + EXPECT_EQ(2 * init_momentum, sampler.rho_values[4]); + EXPECT_EQ(3 * init_momentum, sampler.rho_values[7]); + EXPECT_EQ(2 * init_momentum, sampler.rho_values[10]); + EXPECT_EQ(2 * init_momentum, sampler.rho_values[13]); + EXPECT_EQ(3 * init_momentum, sampler.rho_values[16]); + EXPECT_EQ(5 * init_momentum, sampler.rho_values[19]); + + EXPECT_EQ(2 * init_momentum, sampler.rho_values[2]); + EXPECT_EQ(2 * init_momentum, sampler.rho_values[5]); + EXPECT_EQ(3 * init_momentum, sampler.rho_values[8]); + EXPECT_EQ(2 * init_momentum, sampler.rho_values[11]); + EXPECT_EQ(2 * init_momentum, sampler.rho_values[14]); + EXPECT_EQ(3 * init_momentum, sampler.rho_values[17]); + EXPECT_EQ(5 * init_momentum, sampler.rho_values[20]); - for (int n = 0; n < 21; ++n) { - EXPECT_EQ(rho_scales[n] * init_momentum, sampler.rho_values[n]); - } } TEST(McmcNutsBaseNuts, divergence_test) { From 8a65f5b4a265cd412a7f10baecb537d411a926ad Mon Sep 17 00:00:00 2001 From: betanalpha Date: Thu, 29 Aug 2019 14:15:08 -0400 Subject: [PATCH 08/12] Fix bug in updating edge momenta before each expansion --- src/stan/mcmc/hmc/nuts/base_nuts.hpp | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/src/stan/mcmc/hmc/nuts/base_nuts.hpp b/src/stan/mcmc/hmc/nuts/base_nuts.hpp index f0442c5ba95..2e153660c70 100644 --- a/src/stan/mcmc/hmc/nuts/base_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/base_nuts.hpp @@ -104,8 +104,8 @@ namespace stan { Eigen::VectorXd p_sharp_bck_bck = p_sharp_fwd_fwd; // Integrated momenta along trajectory - Eigen::VectorXd rho = this->z_.p; - + Eigen::VectorXd rho = this->z_.p.transpose(); + // Log sum of state weights (offset by H0) along trajectory double log_sum_weight = 0; // log(exp(H0 - H0)) double H0 = this->hamiltonian_.H(this->z_); @@ -130,6 +130,9 @@ namespace stan { // Extend the current trajectory forward this->z_.ps_point::operator=(z_fwd); rho_bck = rho; + p_bck_fwd = p_fwd_fwd; + p_sharp_bck_fwd = p_sharp_fwd_fwd; + valid_subtree = build_tree(this->depth_, z_propose, p_sharp_fwd_bck, p_sharp_fwd_fwd, @@ -142,6 +145,9 @@ namespace stan { // Extend the current trajectory backwards this->z_.ps_point::operator=(z_bck); rho_fwd = rho; + p_fwd_bck = p_bck_bck; + p_sharp_fwd_bck = p_sharp_bck_bck; + valid_subtree = build_tree(this->depth_, z_propose, p_sharp_bck_fwd, p_sharp_bck_bck, From 841c6a1c056514d02d37ce8bdb636e3d3fbc7f2d Mon Sep 17 00:00:00 2001 From: betanalpha Date: Thu, 29 Aug 2019 14:59:44 -0400 Subject: [PATCH 09/12] Lint --- src/stan/mcmc/hmc/nuts/base_nuts.hpp | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/stan/mcmc/hmc/nuts/base_nuts.hpp b/src/stan/mcmc/hmc/nuts/base_nuts.hpp index 2e153660c70..5c27cd3025f 100644 --- a/src/stan/mcmc/hmc/nuts/base_nuts.hpp +++ b/src/stan/mcmc/hmc/nuts/base_nuts.hpp @@ -105,7 +105,7 @@ namespace stan { // Integrated momenta along trajectory Eigen::VectorXd rho = this->z_.p.transpose(); - + // Log sum of state weights (offset by H0) along trajectory double log_sum_weight = 0; // log(exp(H0 - H0)) double H0 = this->hamiltonian_.H(this->z_); @@ -132,7 +132,7 @@ namespace stan { rho_bck = rho; p_bck_fwd = p_fwd_fwd; p_sharp_bck_fwd = p_sharp_fwd_fwd; - + valid_subtree = build_tree(this->depth_, z_propose, p_sharp_fwd_bck, p_sharp_fwd_fwd, @@ -147,7 +147,7 @@ namespace stan { rho_fwd = rho; p_fwd_bck = p_bck_bck; p_sharp_fwd_bck = p_sharp_bck_bck; - + valid_subtree = build_tree(this->depth_, z_propose, p_sharp_bck_fwd, p_sharp_bck_bck, From fc7df125b92dbe52d7620e036ff207de84ed25e6 Mon Sep 17 00:00:00 2001 From: betanalpha Date: Thu, 29 Aug 2019 15:00:16 -0400 Subject: [PATCH 10/12] Add test sensitive to edge momenta in transition --- src/test/unit/mcmc/hmc/mock_hmc.hpp | 2 +- .../unit/mcmc/hmc/nuts/base_nuts_test.cpp | 98 ++++++++++++++++++- 2 files changed, 96 insertions(+), 4 deletions(-) diff --git a/src/test/unit/mcmc/hmc/mock_hmc.hpp b/src/test/unit/mcmc/hmc/mock_hmc.hpp index 33c9d29028c..0a79bc9233d 100644 --- a/src/test/unit/mcmc/hmc/mock_hmc.hpp +++ b/src/test/unit/mcmc/hmc/mock_hmc.hpp @@ -64,7 +64,7 @@ namespace stan { // Ensures that NUTS non-termination criterion is always true Eigen::VectorXd dtau_dp(ps_point& z) { - return Eigen::VectorXd::Ones(this->model_.num_params_r()); + return z.q; } Eigen::VectorXd dphi_dq(ps_point& z, diff --git a/src/test/unit/mcmc/hmc/nuts/base_nuts_test.cpp b/src/test/unit/mcmc/hmc/nuts/base_nuts_test.cpp index 1955083a0d8..798f8e347e8 100644 --- a/src/test/unit/mcmc/hmc/nuts/base_nuts_test.cpp +++ b/src/test/unit/mcmc/hmc/nuts/base_nuts_test.cpp @@ -20,13 +20,18 @@ namespace stan { mock_nuts(const mock_model &m, rng_t& rng) : base_nuts(m, rng) { } + + bool compute_criterion(Eigen::VectorXd& p_sharp_minus, + Eigen::VectorXd& p_sharp_plus, + Eigen::VectorXd& rho) { + return true; + } }; class rho_inspector_mock_nuts: public base_nuts { - public: std::vector rho_values; rho_inspector_mock_nuts(const mock_model &m, rng_t& rng) @@ -41,6 +46,26 @@ namespace stan { } }; + class edge_inspector_mock_nuts: public base_nuts { + public: + std::vector p_sharp_minus_values; + std::vector p_sharp_plus_values; + edge_inspector_mock_nuts(const mock_model &m, rng_t& rng) + : base_nuts(m, rng) + { } + + bool compute_criterion(Eigen::VectorXd& p_sharp_minus, + Eigen::VectorXd& p_sharp_plus, + Eigen::VectorXd& rho) { + p_sharp_minus_values.push_back(p_sharp_minus(0)); + p_sharp_plus_values.push_back(p_sharp_plus(0)); + return true; + } + }; + // Mock Hamiltonian template class divergent_hamiltonian @@ -182,9 +207,9 @@ TEST(McmcNutsBaseNuts, build_tree_test) { EXPECT_EQ(init_momentum * (n_leapfrog + 1), rho(0)); EXPECT_EQ(1.5, p_begin(0)); - EXPECT_EQ(1, p_sharp_begin(0)); + EXPECT_EQ(1.5, p_sharp_begin(0)); EXPECT_EQ(1.5, p_end(0)); - EXPECT_EQ(1, p_sharp_end(0)); + EXPECT_EQ(12, p_sharp_end(0)); EXPECT_EQ(8 * init_momentum, sampler.z().q(0)); EXPECT_EQ(init_momentum, sampler.z().p(0)); @@ -389,3 +414,70 @@ TEST(McmcNutsBaseNuts, transition) { EXPECT_EQ("", error.str()); EXPECT_EQ("", fatal.str()); } + +TEST(McmcNutsBaseNuts, transition_egde_momenta) { + + rng_t base_rng(0); + + int model_size = 1; + double init_momentum = 1.5; + + stan::mcmc::ps_point z_init(model_size); + z_init.q(0) = 0; + z_init.p(0) = init_momentum; + + stan::mcmc::mock_model model(model_size); + stan::mcmc::edge_inspector_mock_nuts sampler(model, base_rng); + + sampler.set_max_depth(2); + + sampler.set_nominal_stepsize(1); + sampler.set_stepsize_jitter(0); + sampler.sample_stepsize(); + sampler.z() = z_init; + + std::stringstream debug, info, warn, error, fatal; + stan::callbacks::stream_logger logger(debug, info, warn, error, fatal); + + stan::mcmc::sample init_sample(z_init.q, 0, 0); + + // Transition will expand trajectory until max_depth is hit + stan::mcmc::sample s = sampler.transition(init_sample, logger); + + EXPECT_EQ(2, sampler.depth_); + EXPECT_EQ((2 << (sampler.get_max_depth() - 1)) - 1, sampler.n_leapfrog_); + EXPECT_FALSE(sampler.divergent_); + + + EXPECT_EQ(9, sampler.p_sharp_minus_values.size()); + + // Depth 0 Transition Check + EXPECT_EQ(0, sampler.p_sharp_minus_values[0]); + EXPECT_EQ(init_momentum, sampler.p_sharp_plus_values[0]); + + EXPECT_EQ(0, sampler.p_sharp_minus_values[1]); + EXPECT_EQ(init_momentum, sampler.p_sharp_plus_values[1]); + + EXPECT_EQ(0, sampler.p_sharp_minus_values[2]); + EXPECT_EQ(init_momentum, sampler.p_sharp_plus_values[2]); + + // Depth 1 Build Tree Check + EXPECT_EQ(2 * init_momentum, sampler.p_sharp_minus_values[3]); + EXPECT_EQ(3 * init_momentum, sampler.p_sharp_plus_values[3]); + + EXPECT_EQ(2 * init_momentum, sampler.p_sharp_minus_values[4]); + EXPECT_EQ(3 * init_momentum, sampler.p_sharp_plus_values[4]); + + EXPECT_EQ(2 * init_momentum, sampler.p_sharp_minus_values[5]); + EXPECT_EQ(3 * init_momentum, sampler.p_sharp_plus_values[5]); + + // Depth 1 Transition Check + EXPECT_EQ(0, sampler.p_sharp_minus_values[6]); + EXPECT_EQ(3 * init_momentum, sampler.p_sharp_plus_values[6]); + + EXPECT_EQ(0, sampler.p_sharp_minus_values[7]); + EXPECT_EQ(2 * init_momentum, sampler.p_sharp_plus_values[7]); + + EXPECT_EQ(init_momentum, sampler.p_sharp_minus_values[8]); + EXPECT_EQ(3 * init_momentum, sampler.p_sharp_plus_values[8]); +} From b2deb9efee0c3afe144d478338024e7df2705172 Mon Sep 17 00:00:00 2001 From: betanalpha Date: Thu, 29 Aug 2019 15:09:04 -0400 Subject: [PATCH 11/12] Update performance test --- src/test/performance/logistic_test.cpp | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/test/performance/logistic_test.cpp b/src/test/performance/logistic_test.cpp index 33c8c14d471..f10c7875dec 100644 --- a/src/test/performance/logistic_test.cpp +++ b/src/test/performance/logistic_test.cpp @@ -108,31 +108,31 @@ TEST_F(performance, values_from_tagged_version) { << "last tagged version, 2.17.0, had " << N_values << " elements"; std::vector first_run = last_draws_per_run[0]; - EXPECT_FLOAT_EQ(-66.222504, first_run[0]) + EXPECT_FLOAT_EQ(-65.216301, first_run[0]) << "lp__: index 0"; - EXPECT_FLOAT_EQ(1.0, first_run[1]) + EXPECT_FLOAT_EQ(0.91851199, first_run[1]) << "accept_stat__: index 1"; - EXPECT_FLOAT_EQ(1.00182, first_run[2]) + EXPECT_FLOAT_EQ(0.76885599, first_run[2]) << "stepsize__: index 2"; EXPECT_FLOAT_EQ(2, first_run[3]) << "treedepth__: index 3"; - EXPECT_FLOAT_EQ(7, first_run[4]) + EXPECT_FLOAT_EQ(3, first_run[4]) << "n_leapfrog__: index 4"; EXPECT_FLOAT_EQ(0, first_run[5]) << "divergent__: index 5"; - EXPECT_FLOAT_EQ(67.6213, first_run[6]) + EXPECT_FLOAT_EQ(66.696503, first_run[6]) << "energy__: index 6"; - EXPECT_FLOAT_EQ(1.36304, first_run[7]) + EXPECT_FLOAT_EQ(1.3577, first_run[7]) << "beta.1: index 7"; - EXPECT_FLOAT_EQ(-0.86114103, first_run[8]) + EXPECT_FLOAT_EQ(-0.511895, first_run[8]) << "beta.2: index 8"; matches_tagged_version = !HasNonfatalFailure(); From 7bab596392a5669be69f4c0d84525ef9da0d588b Mon Sep 17 00:00:00 2001 From: betanalpha Date: Thu, 29 Aug 2019 22:59:12 -0400 Subject: [PATCH 12/12] Hand tuning performance test to pass on Jenkins machine --- src/test/performance/logistic_test.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/test/performance/logistic_test.cpp b/src/test/performance/logistic_test.cpp index f10c7875dec..de1b7a7607b 100644 --- a/src/test/performance/logistic_test.cpp +++ b/src/test/performance/logistic_test.cpp @@ -114,7 +114,7 @@ TEST_F(performance, values_from_tagged_version) { EXPECT_FLOAT_EQ(0.91851199, first_run[1]) << "accept_stat__: index 1"; - EXPECT_FLOAT_EQ(0.76885599, first_run[2]) + EXPECT_FLOAT_EQ(0.76885802, first_run[2]) << "stepsize__: index 2"; EXPECT_FLOAT_EQ(2, first_run[3]) @@ -132,7 +132,7 @@ TEST_F(performance, values_from_tagged_version) { EXPECT_FLOAT_EQ(1.3577, first_run[7]) << "beta.1: index 7"; - EXPECT_FLOAT_EQ(-0.511895, first_run[8]) + EXPECT_FLOAT_EQ(-0.51189202, first_run[8]) << "beta.2: index 8"; matches_tagged_version = !HasNonfatalFailure();