Skip to content

Commit

Permalink
Merge pull request #3318 from stan-dev/fix/chainset-quantiles-check-d…
Browse files Browse the repository at this point in the history
…raws
  • Loading branch information
WardBrian authored Nov 29, 2024
2 parents 595ae41 + 5c13934 commit 185f368
Show file tree
Hide file tree
Showing 5 changed files with 146 additions and 5 deletions.
6 changes: 5 additions & 1 deletion src/stan/io/stan_csv_reader.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,11 @@ class stan_csv_reader {
for (int col = 0; col < cols; col++) {
std::getline(ls, line, ',');
boost::trim(line);
std::stringstream(line) >> samples(row, col);
try {
samples(row, col) = static_cast<double>(std::stold(line));
} catch (const std::out_of_range& e) {
samples(row, col) = std::numeric_limits<double>::quiet_NaN();
}
}
}
}
Expand Down
26 changes: 22 additions & 4 deletions src/stan/mcmc/chainset.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -288,21 +288,27 @@ class chainset {

/**
* Compute the quantile value of the specified parameter
* at the specified probability.
* at the specified probability via a call to stan::math::quantile.
*
* Calls stan::math::quantile which throws
* std::invalid_argument If any element of samples_vec is NaN, or size 0.
* and std::domain_error If `p<0` or `p>1`.
* If this happens, error will be caught, quantile value is NaN.
*
* @param index parameter index
* @param prob probability
* @return parameter value at quantile
*/
double quantile(const int index, const double prob) const {
// Ensure the probability is within [0, 1]
Eigen::MatrixXd draws = samples(index);
Eigen::Map<Eigen::VectorXd> map(draws.data(), draws.size());
return stan::math::quantile(map, prob);
double result;
try {
result = stan::math::quantile(map, prob);
} catch (const std::logic_error& e) {
return std::numeric_limits<double>::quiet_NaN();
}
return result;
}

/**
Expand All @@ -321,6 +327,11 @@ class chainset {
* Compute the quantile values of the specified parameter
* for a set of specified probabilities.
*
* Calls stan::math::quantile which throws
* std::invalid_argument If any element of samples_vec is NaN, or size 0.
* and std::domain_error If `p<0` or `p>1`.
* If this happens, error will be caught, quantile value is NaN.
*
* @param index parameter index
* @param probs vector of probabilities
* @return vector of parameter values for quantiles
Expand All @@ -332,7 +343,14 @@ class chainset {
Eigen::MatrixXd draws = samples(index);
Eigen::Map<Eigen::VectorXd> map(draws.data(), draws.size());
std::vector<double> probs_vec(probs.data(), probs.data() + probs.size());
std::vector<double> quantiles = stan::math::quantile(map, probs_vec);
std::vector<double> quantiles;
try {
quantiles = stan::math::quantile(map, probs_vec);
} catch (const std::logic_error& e) {
Eigen::VectorXd nans(probs.size());
nans.setConstant(std::numeric_limits<double>::quiet_NaN());
return nans;
}
return Eigen::Map<Eigen::VectorXd>(quantiles.data(), quantiles.size());
}

Expand Down
11 changes: 11 additions & 0 deletions src/test/unit/io/stan_csv_reader_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -604,3 +604,14 @@ TEST_F(StanIoStanCsvReader, variational) {
ASSERT_EQ(1000, variational.metadata.num_samples);
ASSERT_EQ(0, variational.adaptation.metric.size());
}

TEST_F(StanIoStanCsvReader, read_nans) {
std::ifstream datagen_stream;
datagen_stream.open("src/test/unit/mcmc/test_csv_files/datagen_output.csv",
std::ifstream::in);
std::stringstream out;
stan::io::stan_csv datagen
= stan::io::stan_csv_reader::parse(datagen_stream, &out);
datagen_stream.close();
ASSERT_TRUE(std::isnan(datagen.samples(0, 2)));
}
26 changes: 26 additions & 0 deletions src/test/unit/mcmc/chainset_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,3 +207,29 @@ TEST_F(McmcChains, summary_stats) {
EXPECT_NEAR(theta_ac(i), theta_ac_expect(i), 0.0005);
}
}

TEST_F(McmcChains, quantile_tests) {
std::ifstream datagen_stream;
datagen_stream.open("src/test/unit/mcmc/test_csv_files/datagen_output.csv",
std::ifstream::in);
stan::io::stan_csv datagen_csv
= stan::io::stan_csv_reader::parse(datagen_stream, &out);
datagen_stream.close();
stan::mcmc::chainset datagen_chains(datagen_csv);

Eigen::VectorXd probs(6);
probs << 0.0, 0.01, 0.05, 0.95, 0.99, 1.0;
Eigen::VectorXd stepsize_quantiles
= datagen_chains.quantiles("stepsize__", probs);
for (size_t i = 0; i < probs.size(); ++i) {
EXPECT_TRUE(std::isnan(stepsize_quantiles(i)));
}

Eigen::VectorXd bad_probs(3);
bad_probs << 5, 50, 95;
Eigen::VectorXd y_sim_quantiles
= datagen_chains.quantiles("y_sim[1]", bad_probs);
for (size_t i = 0; i < bad_probs.size(); ++i) {
EXPECT_TRUE(std::isnan(y_sim_quantiles(i)));
}
}
Loading

0 comments on commit 185f368

Please sign in to comment.