Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/issue 2799 robust no u turn #2800

Merged
merged 12 commits into from
Aug 30, 2019
176 changes: 126 additions & 50 deletions src/stan/mcmc/hmc/nuts/base_nuts.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -81,56 +81,80 @@ namespace stan {
this->hamiltonian_.sample_p(this->z_, this->rand_int_);
this->hamiltonian_.init(this->z_, logger);

ps_point z_plus(this->z_);
ps_point z_minus(z_plus);
ps_point z_fwd(this->z_); // State at forward end of trajectory
ps_point z_bck(z_fwd); // State at backward end of trajectory

ps_point z_sample(z_plus);
ps_point z_propose(z_plus);
ps_point z_sample(z_fwd);
ps_point z_propose(z_fwd);

Eigen::VectorXd p_sharp_plus = this->hamiltonian_.dtau_dp(this->z_);
Eigen::VectorXd p_sharp_dummy = p_sharp_plus;
Eigen::VectorXd p_sharp_minus = p_sharp_plus;
// Momentum and sharp momentum at forward end of forward subtree
Eigen::VectorXd p_fwd_fwd = this->z_.p;
Eigen::VectorXd p_sharp_fwd_fwd = this->hamiltonian_.dtau_dp(this->z_);

// Momentum and sharp momentum at backward end of forward subtree
Eigen::VectorXd p_fwd_bck = this->z_.p;
Eigen::VectorXd p_sharp_fwd_bck = p_sharp_fwd_fwd;

// Momentum and sharp momentum at forward end of backward subtree
Eigen::VectorXd p_bck_fwd = this->z_.p;
Eigen::VectorXd p_sharp_bck_fwd = p_sharp_fwd_fwd;

// Momentum and sharp momentum at backward end of backward subtree
Eigen::VectorXd p_bck_bck = this->z_.p;
Eigen::VectorXd p_sharp_bck_bck = p_sharp_fwd_fwd;

// Integrated momenta along trajectory
Eigen::VectorXd rho = this->z_.p;

// Log sum of state weights (offset by H0) along trajectory
double log_sum_weight = 0; // log(exp(H0 - H0))
double H0 = this->hamiltonian_.H(this->z_);
int n_leapfrog = 0;
double sum_metro_prob = 0;

// Build a trajectory until the NUTS criterion is no longer satisfied
// Build a trajectory until the no-u-turn
// criterion is no longer satisfied
this->depth_ = 0;
this->divergent_ = false;

while (this->depth_ < this->max_depth_) {
// Build a new subtree in a random direction
Eigen::VectorXd rho_subtree = Eigen::VectorXd::Zero(rho.size());
Eigen::VectorXd rho_fwd = Eigen::VectorXd::Zero(rho.size());
Eigen::VectorXd rho_bck = Eigen::VectorXd::Zero(rho.size());

bool valid_subtree = false;
double log_sum_weight_subtree
= -std::numeric_limits<double>::infinity();

if (this->rand_uniform_() > 0.5) {
this->z_.ps_point::operator=(z_plus);
// Extend the current trajectory forward
this->z_.ps_point::operator=(z_fwd);
rho_bck = rho;
valid_subtree
= build_tree(this->depth_, z_propose,
p_sharp_dummy, p_sharp_plus, rho_subtree,
p_sharp_fwd_bck, p_sharp_fwd_fwd,
rho_fwd, p_fwd_bck, p_fwd_fwd,
H0, 1, n_leapfrog,
log_sum_weight_subtree, sum_metro_prob,
logger);
z_plus.ps_point::operator=(this->z_);
z_fwd.ps_point::operator=(this->z_);
} else {
this->z_.ps_point::operator=(z_minus);
// Extend the current trajectory backwards
this->z_.ps_point::operator=(z_bck);
rho_fwd = rho;
valid_subtree
= build_tree(this->depth_, z_propose,
p_sharp_dummy, p_sharp_minus, rho_subtree,
p_sharp_bck_fwd, p_sharp_bck_bck,
rho_bck, p_bck_fwd, p_bck_bck,
H0, -1, n_leapfrog,
log_sum_weight_subtree, sum_metro_prob,
logger);
z_minus.ps_point::operator=(this->z_);
z_bck.ps_point::operator=(this->z_);
}

if (!valid_subtree) break;

// Sample from an accepted subtree
// Sample from accepted subtree
++(this->depth_);

if (log_sum_weight_subtree > log_sum_weight) {
Expand All @@ -145,9 +169,30 @@ namespace stan {
log_sum_weight
= math::log_sum_exp(log_sum_weight, log_sum_weight_subtree);

// Break when NUTS criterion is no longer satisfied
rho += rho_subtree;
if (!compute_criterion(p_sharp_minus, p_sharp_plus, rho))
// Break when no-u-turn criterion is no longer satisfied
rho = rho_bck + rho_fwd;

// Demand satisfaction around merged subtrees
bool persist_criterion =
compute_criterion(p_sharp_bck_bck,
p_sharp_fwd_fwd,
rho);

// Demand satisfaction between subtrees
Eigen::VectorXd rho_extended = rho_bck + p_fwd_bck;

persist_criterion &=
compute_criterion(p_sharp_bck_bck,
p_sharp_fwd_bck,
rho_extended);

rho_extended = rho_fwd + p_bck_fwd;
persist_criterion &=
compute_criterion(p_sharp_bck_fwd,
p_sharp_fwd_fwd,
rho_extended);

if (!persist_criterion)
break;
}

Expand Down Expand Up @@ -193,9 +238,11 @@ namespace stan {
*
* @param depth Depth of the desired subtree
* @param z_propose State proposed from subtree
* @param p_sharp_left p_sharp from left boundary of returned tree
* @param p_sharp_right p_sharp from the right boundary of returned tree
* @param p_sharp_beg Sharp momentum at beginning of new tree
* @param p_sharp_end Sharp momentum at end of new tree
* @param rho Summed momentum across trajectory
* @param p_beg Momentum at beginning of returned tree
* @param p_end Momentum at end of returned tree
* @param H0 Hamiltonian of initial state
* @param sign Direction in time to built subtree
* @param n_leapfrog Summed number of leapfrog evaluations
Expand All @@ -204,9 +251,11 @@ namespace stan {
* @param logger Logger for messages
*/
bool build_tree(int depth, ps_point& z_propose,
Eigen::VectorXd& p_sharp_left,
Eigen::VectorXd& p_sharp_right,
Eigen::VectorXd& p_sharp_beg,
Eigen::VectorXd& p_sharp_end,
Eigen::VectorXd& rho,
Eigen::VectorXd& p_beg,
Eigen::VectorXd& p_end,
double H0, double sign, int& n_leapfrog,
double& log_sum_weight, double& sum_metro_prob,
callbacks::logger& logger) {
Expand All @@ -231,63 +280,90 @@ namespace stan {
sum_metro_prob += std::exp(H0 - h);

z_propose = this->z_;
rho += this->z_.p;

p_sharp_left = this->hamiltonian_.dtau_dp(this->z_);
p_sharp_right = p_sharp_left;
p_sharp_beg = this->hamiltonian_.dtau_dp(this->z_);
p_sharp_end = p_sharp_beg;

rho += this->z_.p;
p_beg = this->z_.p;
p_end = p_beg;

return !this->divergent_;
}
// General recursion
Eigen::VectorXd p_sharp_dummy(this->z_.p.size());

// Build the left subtree
double log_sum_weight_left = -std::numeric_limits<double>::infinity();
Eigen::VectorXd rho_left = Eigen::VectorXd::Zero(rho.size());
// Build the initial subtree
double log_sum_weight_init = -std::numeric_limits<double>::infinity();

// Momentum and sharp momentum at end of the initial subtree
Eigen::VectorXd p_init_end(this->z_.p.size());
Eigen::VectorXd p_sharp_init_end(this->z_.p.size());

Eigen::VectorXd rho_init = Eigen::VectorXd::Zero(rho.size());

bool valid_left
bool valid_init
= build_tree(depth - 1, z_propose,
p_sharp_left, p_sharp_dummy, rho_left,
p_sharp_beg, p_sharp_init_end,
rho_init, p_beg, p_init_end,
H0, sign, n_leapfrog,
log_sum_weight_left, sum_metro_prob,
log_sum_weight_init, sum_metro_prob,
logger);

if (!valid_left) return false;
if (!valid_init) return false;

// Build the right subtree
ps_point z_propose_right(this->z_);
// Build the final subtree
ps_point z_propose_final(this->z_);

double log_sum_weight_right = -std::numeric_limits<double>::infinity();
Eigen::VectorXd rho_right = Eigen::VectorXd::Zero(rho.size());
double log_sum_weight_final = -std::numeric_limits<double>::infinity();

bool valid_right
= build_tree(depth - 1, z_propose_right,
p_sharp_dummy, p_sharp_right, rho_right,
// Momentum and sharp momentum at beginning of the final subtree
Eigen::VectorXd p_final_beg(this->z_.p.size());
Eigen::VectorXd p_sharp_final_beg(this->z_.p.size());

Eigen::VectorXd rho_final = Eigen::VectorXd::Zero(rho.size());

bool valid_final
= build_tree(depth - 1, z_propose_final,
p_sharp_final_beg, p_sharp_end,
rho_final, p_final_beg, p_end,
H0, sign, n_leapfrog,
log_sum_weight_right, sum_metro_prob,
log_sum_weight_final, sum_metro_prob,
logger);

if (!valid_right) return false;
if (!valid_final) return false;

// Multinomial sample from right subtree
double log_sum_weight_subtree
= math::log_sum_exp(log_sum_weight_left, log_sum_weight_right);
= math::log_sum_exp(log_sum_weight_init, log_sum_weight_final);
log_sum_weight
= math::log_sum_exp(log_sum_weight, log_sum_weight_subtree);

if (log_sum_weight_right > log_sum_weight_subtree) {
z_propose = z_propose_right;
if (log_sum_weight_final > log_sum_weight_subtree) {
z_propose = z_propose_final;
} else {
double accept_prob
= std::exp(log_sum_weight_right - log_sum_weight_subtree);
= std::exp(log_sum_weight_final - log_sum_weight_subtree);
if (this->rand_uniform_() < accept_prob)
z_propose = z_propose_right;
z_propose = z_propose_final;
}

Eigen::VectorXd rho_subtree = rho_left + rho_right;
Eigen::VectorXd rho_subtree = rho_init + rho_final;
rho += rho_subtree;

return compute_criterion(p_sharp_left, p_sharp_right, rho_subtree);
// Demand satisfaction around merged subtrees
bool persist_criterion =
compute_criterion(p_sharp_beg, p_sharp_end, rho_subtree);

// Demand satisfaction between subtrees
rho_subtree = rho_init + p_final_beg;
persist_criterion &=
compute_criterion(p_sharp_beg, p_sharp_final_beg, rho_subtree);

rho_subtree = rho_final + p_init_end;
persist_criterion &=
compute_criterion(p_sharp_init_end, p_sharp_end, rho_subtree);

return persist_criterion;
}

int depth_;
Expand Down
10 changes: 5 additions & 5 deletions src/test/performance/logistic_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,13 +108,13 @@ TEST_F(performance, values_from_tagged_version) {
<< "last tagged version, 2.17.0, had " << N_values << " elements";

std::vector<double> first_run = last_draws_per_run[0];
EXPECT_FLOAT_EQ(-65.781998, first_run[0])
EXPECT_FLOAT_EQ(-66.222504, first_run[0])
<< "lp__: index 0";

EXPECT_FLOAT_EQ(1.0, first_run[1])
<< "accept_stat__: index 1";

EXPECT_FLOAT_EQ(0.76853198, first_run[2])
EXPECT_FLOAT_EQ(1.00182, first_run[2])
<< "stepsize__: index 2";

EXPECT_FLOAT_EQ(2, first_run[3])
Expand All @@ -126,13 +126,13 @@ TEST_F(performance, values_from_tagged_version) {
EXPECT_FLOAT_EQ(0, first_run[5])
<< "divergent__: index 5";

EXPECT_FLOAT_EQ(66.6695, first_run[6])
EXPECT_FLOAT_EQ(67.6213, first_run[6])
<< "energy__: index 6";

EXPECT_FLOAT_EQ(1.55186, first_run[7])
EXPECT_FLOAT_EQ(1.36304, first_run[7])
<< "beta.1: index 7";

EXPECT_FLOAT_EQ(-0.52400702, first_run[8])
EXPECT_FLOAT_EQ(-0.86114103, first_run[8])
<< "beta.2: index 8";

matches_tagged_version = !HasNonfatalFailure();
Expand Down
Loading