diff --git a/docs/protos.html b/docs/protos.html
index a7010afa1..7263f6160 100644
--- a/docs/protos.html
+++ b/docs/protos.html
@@ -457,10 +457,6 @@
Table of Contents
MTruncSBPrior.DPPrior
-
- MTruncSBPrior.MFMPrior
-
-
MTruncSBPrior.PYPrior
@@ -2369,24 +2365,18 @@ TruncSBPrior
- mfm_prior |
- TruncSBPrior.MFMPrior |
- |
- |
-
-
-
- pm_prior |
- PythonMixPrior |
+ num_components |
+ uint32 |
|
- |
+ Number of components in the process |
- num_components |
- uint32 |
+ infinite_mixture |
+ bool |
|
- Number of components in the process |
+ If true we must use the Slice Sampler, and num_components is used only for
+the initialization |
@@ -2444,30 +2434,6 @@ TruncSBPrior.DPPrior
- TruncSBPrior.MFMPrior
-
-
-
-
-
- Field | Type | Label | Description |
-
-
-
-
- totalmass |
- double |
- |
- Truncated Dirichlet process |
-
-
-
-
-
-
-
-
-
TruncSBPrior.PYPrior
diff --git a/src/hierarchies/CMakeLists.txt b/src/hierarchies/CMakeLists.txt
index 42a45fd87..8426cae04 100644
--- a/src/hierarchies/CMakeLists.txt
+++ b/src/hierarchies/CMakeLists.txt
@@ -8,6 +8,7 @@ target_sources(bayesmix
lin_reg_uni_hierarchy.h
fa_hierarchy.h
lapnig_hierarchy.h
+ betagg_hierarchy.h
)
add_subdirectory(likelihoods)
diff --git a/src/hierarchies/betagg_hierarchy.h b/src/hierarchies/betagg_hierarchy.h
new file mode 100644
index 000000000..c2e487679
--- /dev/null
+++ b/src/hierarchies/betagg_hierarchy.h
@@ -0,0 +1,57 @@
+#ifndef BAYESMIX_HIERARCHIES_BETA_GG_HIERARCHY_H_
+#define BAYESMIX_HIERARCHIES_BETA_GG_HIERARCHY_H_
+
+#include "base_hierarchy.h"
+#include "hierarchy_id.pb.h"
+#include "likelihoods/beta_likelihood.h"
+#include "priors/gamma_gamma_prior.h"
+#include "updaters/random_walk_updater.h"
+
+/**
+ * Beta Gamma-Gamma hierarchy for univaraite data in [0, 1]
+ *
+ * This class represents a hierarchical model where data are distributed
+ * according to a Beta likelihood (see the `BetaLikelihood` class for
+ * details). The shape and rate parameters of the likelihood have
+ * independent gamma priors. That is
+ *
+ * \f[
+ * f(x_i \mid \alpha, \beta) &= Beta(\alpha, \beta) \\
+ * \alpha &\sim Gamma(\alpha_a, \alpha_b) \\
+ * \beta &\sim Gamma(\beta_a, \beta_b)
+ * \f]
+ *
+ * The state is composed of shape and rate. Note that this hierarchy
+ * is NOT conjugate, meaning that the marginal distribution is not available
+ * in closed form.
+ */
+
+class BetaGGHierarchy
+ : public BaseHierarchy {
+ public:
+ BetaGGHierarchy() = default;
+ ~BetaGGHierarchy() = default;
+
+ //! Returns the Protobuf ID associated to this class
+ bayesmix::HierarchyId get_id() const override {
+ return bayesmix::HierarchyId::BetaGG;
+ }
+
+ //! Sets the default updater algorithm for this hierarchy
+ void set_default_updater() {
+ updater = std::make_shared(0.1);
+ }
+
+ //! Initializes state parameters to appropriate values
+ void initialize_state() override {
+ // Get hypers
+ auto hypers = prior->get_hypers();
+ // Initialize likelihood state
+ State::ShapeRate state;
+ state.shape = hypers.a_shape / hypers.a_rate;
+ state.rate = hypers.b_shape / hypers.b_rate;
+ like->set_state(state);
+ };
+};
+
+#endif // BAYESMIX_HIERARCHIES_BETA_GG_HIERARCHY_H_
diff --git a/src/hierarchies/likelihoods/CMakeLists.txt b/src/hierarchies/likelihoods/CMakeLists.txt
index df10e8674..15dea96c0 100644
--- a/src/hierarchies/likelihoods/CMakeLists.txt
+++ b/src/hierarchies/likelihoods/CMakeLists.txt
@@ -12,6 +12,8 @@ target_sources(bayesmix PUBLIC
laplace_likelihood.cc
fa_likelihood.h
fa_likelihood.cc
+ beta_likelihood.h
+ beta_likelihood.cc
)
add_subdirectory(states)
diff --git a/src/hierarchies/likelihoods/beta_likelihood.cc b/src/hierarchies/likelihoods/beta_likelihood.cc
new file mode 100644
index 000000000..ddd08df13
--- /dev/null
+++ b/src/hierarchies/likelihoods/beta_likelihood.cc
@@ -0,0 +1,22 @@
+#include "beta_likelihood.h"
+
+double BetaLikelihood::compute_lpdf(const Eigen::RowVectorXd &datum) const {
+ return stan::math::beta_lpdf(datum(0), state.shape, state.rate);
+}
+
+void BetaLikelihood::update_sum_stats(const Eigen::RowVectorXd &datum,
+ bool add) {
+ double x = datum(0);
+ if (add) {
+ sum_logs += std::log(x);
+ sum_logs1m += std::log(1. - x);
+ } else {
+ sum_logs -= std::log(x);
+ sum_logs1m -= std::log(1. - x);
+ }
+}
+
+void BetaLikelihood::clear_summary_statistics() {
+ sum_logs = 0;
+ sum_logs1m = 0;
+}
diff --git a/src/hierarchies/likelihoods/beta_likelihood.h b/src/hierarchies/likelihoods/beta_likelihood.h
new file mode 100644
index 000000000..d99a2aa84
--- /dev/null
+++ b/src/hierarchies/likelihoods/beta_likelihood.h
@@ -0,0 +1,56 @@
+#ifndef BAYESMIX_HIERARCHIES_LIKELIHOODS_BETA_LIKELIHOOD_H_
+#define BAYESMIX_HIERARCHIES_LIKELIHOODS_BETA_LIKELIHOOD_H_
+
+#include
+
+#include
+#include
+#include
+
+#include "algorithm_state.pb.h"
+#include "base_likelihood.h"
+#include "states/includes.h"
+
+/**
+ * A univariate Beta likelihood, using the `State::ShapeRate` state. Represents
+ * the model:
+ *
+ * \f[
+ * y_1,\dots,y_k \mid \mu, \sigma^2 \stackrel{\small\mathrm{iid}}{\sim}
+ * Beta(a, b),
+ * \f]
+ */
+
+class BetaLikelihood
+ : public BaseLikelihood {
+ public:
+ BetaLikelihood() = default;
+ ~BetaLikelihood() = default;
+ bool is_multivariate() const override { return false; };
+ bool is_dependent() const override { return false; };
+ void clear_summary_statistics() override;
+
+ template
+ T cluster_lpdf_from_unconstrained(
+ const Eigen::Matrix &unconstrained_params) const {
+ assert(unconstrained_params.size() == 2);
+
+ T a = stan::math::positive_constrain(unconstrained_params(0));
+ T b = stan::math::positive_constrain(unconstrained_params(1));
+
+ T out = 0.;
+
+ return (a - 1.) * sum_logs + (b - 1.) * sum_logs1m -
+ card * (stan::math::lgamma(a) + stan::math::lgamma(b) -
+ stan::math::lgamma(a + b));
+ }
+
+ protected:
+ double compute_lpdf(const Eigen::RowVectorXd &datum) const override;
+ void update_sum_stats(const Eigen::RowVectorXd &datum, bool add) override;
+
+ double sum_logs = 0;
+ double sum_logs1m = 0;
+};
+
+#endif // BAYESMIX_HIERARCHIES_LIKELIHOODS_BETA_LIKELIHOOD_H_
diff --git a/src/hierarchies/likelihoods/states/includes.h b/src/hierarchies/likelihoods/states/includes.h
index f4f868c52..850bc25c8 100644
--- a/src/hierarchies/likelihoods/states/includes.h
+++ b/src/hierarchies/likelihoods/states/includes.h
@@ -3,6 +3,7 @@
#include "fa_state.h"
#include "multi_ls_state.h"
+#include "shape_rate_state.h"
#include "uni_lin_reg_ls_state.h"
#include "uni_ls_state.h"
diff --git a/src/hierarchies/likelihoods/states/shape_rate_state.h b/src/hierarchies/likelihoods/states/shape_rate_state.h
new file mode 100644
index 000000000..c1ffbcc2a
--- /dev/null
+++ b/src/hierarchies/likelihoods/states/shape_rate_state.h
@@ -0,0 +1,88 @@
+#ifndef BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_SHAPE_RATE_STATE_H_
+#define BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_SHAPE_RATE_STATE_H_
+
+#include
+#include
+
+#include "algorithm_state.pb.h"
+#include "base_state.h"
+#include "src/utils/proto_utils.h"
+
+namespace State {
+
+//! Returns the constrained parametrization from the
+//! unconstrained one, i.e. [in[0], exp(in[1])]
+template
+Eigen::Matrix shape_rate_to_constrained(
+ Eigen::Matrix in) {
+ Eigen::Matrix out(2);
+ out << stan::math::exp(in(0)), stan::math::exp(in(1));
+ return out;
+}
+
+//! Returns the unconstrained parametrization from the
+//! constrained one, i.e. [log(in[0]), log(in[1])]
+template
+Eigen::Matrix shape_rate_to_unconstrained(
+ Eigen::Matrix in) {
+ Eigen::Matrix out(2);
+ out << stan::math::log(in(0)), stan::math::log(in(1));
+ return out;
+}
+
+//! Returns the log determinant of the jacobian of the map
+//! (x, y) -> (log(x), log(y)), that is the inverse map of the
+//! constrained -> unconstrained representation.
+template
+T shape_rate_log_det_jac(Eigen::Matrix constrained) {
+ T out = 0;
+ stan::math::positive_constrain(stan::math::log(constrained(0)), out);
+ stan::math::positive_constrain(stan::math::log(constrained(1)), out);
+ return out;
+}
+
+//! A univariate shape-rate state
+//! The unconstrained representation corresponds to (log(shape), log(rate))
+class ShapeRate : public BaseState {
+ public:
+ double shape, rate;
+
+ using ProtoState = bayesmix::AlgorithmState::ClusterState;
+
+ Eigen::VectorXd get_unconstrained() const override {
+ Eigen::VectorXd temp(2);
+ temp << shape, rate;
+ return shape_rate_to_unconstrained(temp);
+ }
+
+ void set_from_unconstrained(const Eigen::VectorXd &in) override {
+ Eigen::VectorXd temp = shape_rate_to_constrained(in);
+ shape = temp(0);
+ rate = temp(1);
+ }
+
+ void set_from_proto(const ProtoState &state_, bool update_card) override {
+ if (update_card) {
+ card = state_.cardinality();
+ }
+ shape = state_.sr_state().shape();
+ rate = state_.sr_state().rate();
+ }
+
+ ProtoState get_as_proto() const override {
+ ProtoState state;
+ state.mutable_sr_state()->set_shape(shape);
+ state.mutable_sr_state()->set_rate(rate);
+ return state;
+ }
+
+ double log_det_jac() const override {
+ Eigen::VectorXd temp(2);
+ temp << shape, rate;
+ return shape_rate_log_det_jac(temp);
+ }
+};
+
+} // namespace State
+
+#endif // BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_SHAPE_RATE_STATE_H_
diff --git a/src/hierarchies/load_hierarchies.h b/src/hierarchies/load_hierarchies.h
index 21ec46ae5..61c8fb18a 100644
--- a/src/hierarchies/load_hierarchies.h
+++ b/src/hierarchies/load_hierarchies.h
@@ -5,6 +5,7 @@
#include
#include "abstract_hierarchy.h"
+#include "betagg_hierarchy.h"
#include "fa_hierarchy.h"
#include "hierarchy_id.pb.h"
#include "lapnig_hierarchy.h"
@@ -43,6 +44,9 @@ __attribute__((constructor)) static void load_hierarchies() {
Builder LapNIGbuilder = []() {
return std::make_shared();
};
+ Builder BetaGGbuilder = []() {
+ return std::make_shared();
+ };
factory.add_builder(NNIGHierarchy().get_id(), NNIGbuilder);
factory.add_builder(NNxIGHierarchy().get_id(), NNxIGbuilder);
@@ -50,6 +54,7 @@ __attribute__((constructor)) static void load_hierarchies() {
factory.add_builder(LinRegUniHierarchy().get_id(), LinRegUnibuilder);
factory.add_builder(FAHierarchy().get_id(), FAbuilder);
factory.add_builder(LapNIGHierarchy().get_id(), LapNIGbuilder);
+ factory.add_builder(BetaGGHierarchy().get_id(), BetaGGbuilder);
}
#endif // BAYESMIX_HIERARCHIES_LOAD_HIERARCHIES_H_
diff --git a/src/hierarchies/priors/CMakeLists.txt b/src/hierarchies/priors/CMakeLists.txt
index d6901ee65..53ad75518 100644
--- a/src/hierarchies/priors/CMakeLists.txt
+++ b/src/hierarchies/priors/CMakeLists.txt
@@ -13,4 +13,6 @@ target_sources(bayesmix PUBLIC
mnig_prior_model.cc
fa_prior_model.h
fa_prior_model.cc
+ gamma_gamma_prior.h
+ gamma_gamma_prior.cc
)
diff --git a/src/hierarchies/priors/gamma_gamma_prior.cc b/src/hierarchies/priors/gamma_gamma_prior.cc
new file mode 100644
index 000000000..9e6c58d20
--- /dev/null
+++ b/src/hierarchies/priors/gamma_gamma_prior.cc
@@ -0,0 +1,67 @@
+#include "gamma_gamma_prior.h"
+
+double GGPriorModel::lpdf(const google::protobuf::Message &state_) {
+ // Downcast state
+ auto &state = downcast_state(state_).sr_state();
+ double target = 0.;
+ target +=
+ stan::math::gamma_lpdf(state.shape(), hypers->a_shape, hypers->a_rate);
+ target +=
+ stan::math::gamma_lpdf(state.rate(), hypers->b_shape, hypers->b_rate);
+ return target;
+}
+
+State::ShapeRate GGPriorModel::sample(ProtoHypersPtr hier_hypers) {
+ // Random seed
+ auto &rng = bayesmix::Rng::Instance().get();
+
+ // Get params to use
+ auto params = get_hypers_proto()->gg_state();
+ State::ShapeRate out;
+ out.shape = stan::math::gamma_rng(params.a_shape(), params.a_rate(), rng);
+ out.rate = stan::math::gamma_rng(params.b_shape(), params.b_rate(), rng);
+ return out;
+}
+
+void GGPriorModel::update_hypers(
+ const std::vector &states) {
+ auto &rng = bayesmix::Rng::Instance().get();
+ if (prior->has_fixed_values()) {
+ return;
+ } else {
+ throw std::invalid_argument("Unrecognized hierarchy prior");
+ }
+}
+
+void GGPriorModel::set_hypers_from_proto(
+ const google::protobuf::Message &hypers_) {
+ auto &hyperscast = downcast_hypers(hypers_).gg_state();
+ hypers->a_shape = hyperscast.a_shape();
+ hypers->a_rate = hyperscast.a_rate();
+ hypers->b_shape = hyperscast.b_shape();
+ hypers->b_rate = hyperscast.b_rate();
+}
+
+std::shared_ptr
+GGPriorModel::get_hypers_proto() const {
+ bayesmix::GamGamDistribution hypers_;
+ hypers_.set_a_shape(hypers->a_shape);
+ hypers_.set_a_rate(hypers->a_rate);
+ hypers_.set_b_shape(hypers->b_shape);
+ hypers_.set_b_rate(hypers->b_rate);
+ auto out = std::make_shared();
+ out->mutable_gg_state()->CopyFrom(hypers_);
+ return out;
+}
+
+void GGPriorModel::initialize_hypers() {
+ if (prior->has_fixed_values()) {
+ // Set values
+ hypers->a_shape = prior->fixed_values().a_shape();
+ hypers->a_rate = prior->fixed_values().a_rate();
+ hypers->b_shape = prior->fixed_values().b_shape();
+ hypers->b_rate = prior->fixed_values().b_rate();
+ } else {
+ throw std::invalid_argument("Unrecognized hierarchy prior");
+ }
+}
diff --git a/src/hierarchies/priors/gamma_gamma_prior.h b/src/hierarchies/priors/gamma_gamma_prior.h
new file mode 100644
index 000000000..9ce995771
--- /dev/null
+++ b/src/hierarchies/priors/gamma_gamma_prior.h
@@ -0,0 +1,64 @@
+#ifndef BAYESMIX_HIERARCHIES_PRIORS_GG_PRIOR_MODEL_H_
+#define BAYESMIX_HIERARCHIES_PRIORS_GG_PRIOR_MODEL_H_
+
+#include
+#include
+#include
+
+#include "base_prior_model.h"
+#include "hierarchy_prior.pb.h"
+#include "hyperparams.h"
+#include "src/utils/rng.h"
+
+/*
+ * Prior model for `ShapeRate` states.
+ * This class assumes that the shape and rate are independent and given
+ * Gamma-distributed priors
+ */
+class GGPriorModel
+ : public BasePriorModel {
+ public:
+ using AbstractPriorModel::ProtoHypers;
+ using AbstractPriorModel::ProtoHypersPtr;
+
+ GGPriorModel() = default;
+ ~GGPriorModel() = default;
+
+ double lpdf(const google::protobuf::Message &state_) override;
+
+ template
+ T lpdf_from_unconstrained(
+ const Eigen::Matrix &unconstrained_params) const {
+ Eigen::Matrix constrained_params =
+ State::shape_rate_to_constrained(unconstrained_params);
+ // std::cout << "constrained_params: " << constrained_params << std::endl;
+ T log_det_jac = State::shape_rate_log_det_jac(constrained_params);
+ T shape = constrained_params(0);
+ T rate = constrained_params(1);
+ T lpdf = stan::math::gamma_lpdf(shape, hypers->a_shape, hypers->a_rate) +
+ stan::math::gamma_lpdf(rate, hypers->a_shape, hypers->a_rate);
+
+ return lpdf + log_det_jac;
+ }
+
+ State::ShapeRate sample(ProtoHypersPtr hier_hypers = nullptr) override;
+
+ void update_hypers(const std::vector
+ &states) override;
+
+ void set_hypers_from_proto(
+ const google::protobuf::Message &hypers_) override;
+
+ unsigned int get_dim() const { return dim; };
+
+ std::shared_ptr get_hypers_proto()
+ const override;
+
+ protected:
+ void initialize_hypers() override;
+
+ unsigned int dim;
+};
+
+#endif // BAYESMIX_HIERARCHIES_PRIORS_GG_PRIOR_MODEL_H_
diff --git a/src/hierarchies/priors/hyperparams.h b/src/hierarchies/priors/hyperparams.h
index 1aca6dc4a..ccc6270c2 100644
--- a/src/hierarchies/priors/hyperparams.h
+++ b/src/hierarchies/priors/hyperparams.h
@@ -31,6 +31,11 @@ struct FA {
unsigned int q;
};
+struct GG {
+ double a_shape, a_rate;
+ double b_shape, b_rate;
+};
+
} // namespace Hyperparams
#endif // BAYESMIX_HIERARCHIES_PRIORS_HYPERPARAMS_H_
diff --git a/src/proto/algorithm_state.proto b/src/proto/algorithm_state.proto
index 02f6ba07d..ad8b4e641 100644
--- a/src/proto/algorithm_state.proto
+++ b/src/proto/algorithm_state.proto
@@ -26,7 +26,7 @@ message AlgorithmState {
LinRegUniLSState lin_reg_uni_ls_state = 4; // State of a linear regression univariate location-scale family
Vector general_state = 5; // Just a vector of doubles
FAState fa_state = 6; // State of a Mixture of Factor Analysers
-
+ SRState sr_state = 7; // State of a Mixture of Beta distributions
}
int32 cardinality = 3; // How many observations are in this cluster
}
@@ -45,6 +45,7 @@ message AlgorithmState {
MultiNormalIGDistribution lin_reg_uni_state = 4;
NxIGDistribution nnxig_state = 5;
FAPriorDistribution fa_state = 7;
+ GamGamDistribution gg_state = 8;
}
}
HierarchyHypers hierarchy_hypers = 5; // The current values of the hyperparameters of the hierarchy
diff --git a/src/proto/distribution.proto b/src/proto/distribution.proto
index 804813f78..785e2f3ae 100644
--- a/src/proto/distribution.proto
+++ b/src/proto/distribution.proto
@@ -95,3 +95,14 @@ message MultiNormalIGDistribution {
double shape = 3;
double scale = 4;
}
+
+/*
+ * Parameters for the product of two independent Gamma distributions with different
+ * rates and shapes
+ */
+message GamGamDistribution {
+ double a_shape = 1;
+ double a_rate = 2;
+ double b_shape = 3;
+ double b_rate = 4;
+}
diff --git a/src/proto/hierarchy_id.proto b/src/proto/hierarchy_id.proto
index 45e9b4d6d..710c2e075 100644
--- a/src/proto/hierarchy_id.proto
+++ b/src/proto/hierarchy_id.proto
@@ -14,4 +14,5 @@ enum HierarchyId {
FA = 5; // Factor Analysers
NNxIG = 6; // Normal - Normal x Inverse Gamma
PythonHier = 7; // Generic python hierarchy
+ BetaGG = 8; // Beta - GammaGamma hierarchy
}
diff --git a/src/proto/hierarchy_prior.proto b/src/proto/hierarchy_prior.proto
index 09558e0d5..7722e29c3 100644
--- a/src/proto/hierarchy_prior.proto
+++ b/src/proto/hierarchy_prior.proto
@@ -123,3 +123,12 @@ message PythonHierPrior{
Vector values = 1; // values are modified from python
}
}
+
+/*
+ * Prior for the parameters of the base measure in a Beta - GammaGamma hierarchy
+ */
+message GGPrior {
+ oneof prior {
+ GamGamDistribution fixed_values = 1;
+ }
+}
diff --git a/src/proto/ls_state.proto b/src/proto/ls_state.proto
index 50028afde..6794185c5 100644
--- a/src/proto/ls_state.proto
+++ b/src/proto/ls_state.proto
@@ -37,3 +37,11 @@ message FAState {
Matrix eta = 3;
Matrix lambda = 4;
}
+
+/*
+ * Parameters of a shape-rate state, used, e.g., by the BetaLikelihood
+ */
+message SRState {
+ double shape = 1;
+ double rate = 2;
+}
diff --git a/test/hierarchies.cc b/test/hierarchies.cc
index 0a4a94c25..2c62c1a22 100644
--- a/test/hierarchies.cc
+++ b/test/hierarchies.cc
@@ -5,6 +5,7 @@
#include "algorithm_state.pb.h"
#include "ls_state.pb.h"
+#include "src/hierarchies/betagg_hierarchy.h"
#include "src/hierarchies/fa_hierarchy.h"
#include "src/hierarchies/lin_reg_uni_hierarchy.h"
#include "src/hierarchies/nnig_hierarchy.h"
@@ -373,3 +374,44 @@ TEST(fa_hierarchy, sample_given_data) {
ASSERT_TRUE(clusval->DebugString() != clusval2->DebugString());
}
+
+TEST(betagg_hierarchy, sample_given_data) {
+ auto hier = std::make_shared();
+ bayesmix::GGPrior prior;
+
+ prior.mutable_fixed_values()->set_a_rate(2.0);
+ prior.mutable_fixed_values()->set_a_shape(2.0);
+ prior.mutable_fixed_values()->set_b_rate(2.0);
+ prior.mutable_fixed_values()->set_b_shape(2.0);
+ hier->get_mutable_prior()->CopyFrom(prior);
+
+ auto mala_updater = std::make_shared(0.0005);
+ hier->set_updater(mala_updater);
+
+ Eigen::MatrixXd dataset = Eigen::MatrixXd::Ones(100, 1) * 0.1;
+ hier->set_dataset(&dataset);
+ hier->initialize();
+
+ for (int i = 0; i < dataset.rows(); i++) {
+ hier->add_datum(i, dataset.row(i));
+ }
+
+ for (int i = 0; i < 1000; i++) {
+ hier->sample_full_cond();
+ }
+
+ bayesmix::AlgorithmState::ClusterState out;
+ hier->write_state_to_proto(&out);
+
+ double a = out.sr_state().shape();
+ double b = out.sr_state().rate();
+
+ double mean = a / (a + b);
+ double var = a * b / ((a + b) * (a + b) * (a + b + 1));
+
+ ASSERT_LT(mean, 0.2);
+ ASSERT_GT(mean, 0.05);
+ ASSERT_LT(var, 0.1);
+
+ // ASSERT_TRUE(clusval->DebugString() != clusval2->DebugString());
+}
diff --git a/test/likelihoods.cc b/test/likelihoods.cc
index 5223f4f6d..d7e116760 100644
--- a/test/likelihoods.cc
+++ b/test/likelihoods.cc
@@ -6,6 +6,7 @@
#include "algorithm_state.pb.h"
#include "ls_state.pb.h"
+#include "src/hierarchies/likelihoods/beta_likelihood.h"
#include "src/hierarchies/likelihoods/laplace_likelihood.h"
#include "src/hierarchies/likelihoods/multi_norm_likelihood.h"
#include "src/hierarchies/likelihoods/uni_lin_reg_likelihood.h"
@@ -389,3 +390,84 @@ TEST(laplace_likelihood, eval_lpdf_unconstrained) {
clus_lpdf = like->cluster_lpdf_from_unconstrained(unconstrained_params);
ASSERT_TRUE(std::abs(clus_lpdf - lpdf) > 1e-5);
}
+
+TEST(beta_likelihood, set_get_state) {
+ // Instance
+ auto like = std::make_shared();
+
+ // Prepare buffers
+ bayesmix::ShapeRateState state_;
+ bayesmix::AlgorithmState::ClusterState set_state_;
+ bayesmix::AlgorithmState::ClusterState got_state_;
+
+ // Prepare state
+ state_.set_shape(5.23);
+ state_.set_rate(1.02);
+ set_state_.mutable_sr_state()->CopyFrom(state_);
+
+ // Set and get the state
+ like->set_state_from_proto(set_state_);
+ like->write_state_to_proto(&got_state_);
+
+ // Check if they coincides
+ ASSERT_EQ(got_state_.DebugString(), set_state_.DebugString());
+}
+
+TEST(beta_likelihood, eval_lpdf) {
+ // Instance
+ auto like = std::make_shared();
+
+ // Set state from proto
+ bayesmix::ShapeRateState state_;
+ bayesmix::AlgorithmState::ClusterState clust_state_;
+ state_.set_shape(5);
+ state_.set_rate(1);
+ clust_state_.mutable_sr_state()->CopyFrom(state_);
+ like->set_state_from_proto(clust_state_);
+
+ // Add new datum to likelihood
+ Eigen::VectorXd data(3);
+ data << 0.5, 0.6, 0.4;
+
+ // Compute lpdf on this grid of points
+ auto evals = like->lpdf_grid(data);
+ auto like_copy = like->clone();
+ auto evals_copy = like_copy->lpdf_grid(data);
+
+ // Check if they coincides
+ ASSERT_EQ(evals, evals_copy);
+}
+
+TEST(beta_likelihood, eval_lpdf_unconstrained) {
+ // Instance
+ auto like = std::make_shared();
+
+ // Set state from proto
+ bayesmix::ShapeRateState state_;
+ bayesmix::AlgorithmState::ClusterState clust_state_;
+ double a = 5;
+ double b = 1;
+ state_.set_shape(a);
+ state_.set_rate(b);
+ Eigen::VectorXd unconstrained_params(2);
+ unconstrained_params << std::log(a), std::log(b);
+ clust_state_.mutable_sr_state()->CopyFrom(state_);
+ like->set_state_from_proto(clust_state_);
+
+ // Add new datum to likelihood
+ Eigen::VectorXd data(3);
+ data << 0.5, 0.6, 0.4;
+ double lpdf = 0.0;
+ for (int i = 0; i < data.size(); ++i) {
+ like->add_datum(i, data.row(i));
+ lpdf += like->lpdf(data.row(i));
+ }
+
+ double clus_lpdf =
+ like->cluster_lpdf_from_unconstrained(unconstrained_params);
+ ASSERT_NEAR(lpdf, clus_lpdf, 1e-12);
+
+ unconstrained_params(0) = 4.0;
+ clus_lpdf = like->cluster_lpdf_from_unconstrained(unconstrained_params);
+ ASSERT_TRUE(std::abs(clus_lpdf - lpdf) > 1e-5);
+}
diff --git a/test/prior_models.cc b/test/prior_models.cc
index fcf380c4a..f0c624370 100644
--- a/test/prior_models.cc
+++ b/test/prior_models.cc
@@ -6,6 +6,7 @@
#include "algorithm_state.pb.h"
#include "hierarchy_prior.pb.h"
+#include "src/hierarchies/priors/gamma_gamma_prior.h"
#include "src/hierarchies/priors/mnig_prior_model.h"
#include "src/hierarchies/priors/nig_prior_model.h"
#include "src/hierarchies/priors/nw_prior_model.h"
@@ -487,3 +488,85 @@ TEST(mnig_prior_model, sample) {
ASSERT_TRUE(state1.get_as_proto().DebugString() !=
state2.get_as_proto().DebugString());
}
+
+TEST(gg_prior_model, set_get_hypers) {
+ // Instance
+ auto prior = std::make_shared();
+
+ // Prepare buffers
+ bayesmix::GamGamDistribution hypers_;
+ bayesmix::AlgorithmState::HierarchyHypers set_state_;
+ bayesmix::AlgorithmState::HierarchyHypers got_state_;
+
+ // Prepare hypers
+ hypers_.set_a_shape(5.0);
+ hypers_.set_a_rate(1.0);
+ hypers_.set_b_shape(4.0);
+ hypers_.set_b_rate(3.0);
+ set_state_.mutable_gg_state()->CopyFrom(hypers_);
+
+ // Set and get hypers
+ prior->set_hypers_from_proto(set_state_);
+ prior->write_hypers_to_proto(&got_state_);
+
+ // Check if they coincides
+ ASSERT_EQ(got_state_.DebugString(), set_state_.DebugString());
+}
+
+TEST(gg_prior_model, fixed_values_prior) {
+ // Prepare buffers
+ bayesmix::GGPrior prior;
+ bayesmix::AlgorithmState::HierarchyHypers prior_out;
+ std::vector> prior_models;
+ std::vector states;
+
+ // Set fixed value prior
+ prior.mutable_fixed_values()->set_a_shape(5.0);
+ prior.mutable_fixed_values()->set_a_rate(1.0);
+ prior.mutable_fixed_values()->set_b_shape(4.0);
+ prior.mutable_fixed_values()->set_b_rate(3.0);
+
+ // Initialize prior model
+ auto prior_model = std::make_shared();
+ prior_model->get_mutable_prior()->CopyFrom(prior);
+ prior_model->initialize();
+
+ // Check equality before update
+ prior_models.push_back(prior_model);
+ for (size_t i = 1; i < 4; i++) {
+ prior_models.push_back(prior_model->clone());
+ prior_models[i]->write_hypers_to_proto(&prior_out);
+ ASSERT_EQ(prior.fixed_values().DebugString(),
+ prior_out.gg_state().DebugString());
+ }
+
+ // Check equality after update
+ prior_models[0]->update_hypers(states);
+ prior_models[0]->write_hypers_to_proto(&prior_out);
+ for (size_t i = 1; i < 4; i++) {
+ prior_models[i]->write_hypers_to_proto(&prior_out);
+ ASSERT_EQ(prior.fixed_values().DebugString(),
+ prior_out.gg_state().DebugString());
+ }
+}
+
+TEST(gg_prior_model, sample) {
+ // Instance
+ auto prior = std::make_shared();
+
+ // Define prior hypers
+ bayesmix::AlgorithmState::HierarchyHypers hypers_proto;
+ hypers_proto.mutable_gg_state()->set_a_shape(5.0);
+ hypers_proto.mutable_gg_state()->set_a_rate(1.0);
+ hypers_proto.mutable_gg_state()->set_b_shape(4.0);
+ hypers_proto.mutable_gg_state()->set_b_rate(3.0);
+
+ // Set hypers and get sampled state as proto
+ prior->set_hypers_from_proto(hypers_proto);
+ auto state1 = prior->sample();
+ auto state2 = prior->sample();
+
+ // Check if they coincides
+ ASSERT_TRUE(state1.get_as_proto().DebugString() !=
+ state2.get_as_proto().DebugString());
+}