diff --git a/docs/protos.html b/docs/protos.html index a7010afa1..7263f6160 100644 --- a/docs/protos.html +++ b/docs/protos.html @@ -457,10 +457,6 @@

Table of Contents

MTruncSBPrior.DPPrior -
  • - MTruncSBPrior.MFMPrior -
  • -
  • MTruncSBPrior.PYPrior
  • @@ -2369,24 +2365,18 @@

    TruncSBPrior

    - mfm_prior - TruncSBPrior.MFMPrior - -

    - - - - pm_prior - PythonMixPrior + num_components + uint32 -

    +

    Number of components in the process

    - num_components - uint32 + infinite_mixture + bool -

    Number of components in the process

    +

    If true we must use the Slice Sampler, and num_components is used only for +the initialization

    @@ -2444,30 +2434,6 @@

    TruncSBPrior.DPPrior

    -

    TruncSBPrior.MFMPrior

    -

    - - - - - - - - - - - - - - - - -
    FieldTypeLabelDescription
    totalmassdouble

    Truncated Dirichlet process

    - - - - -

    TruncSBPrior.PYPrior

    diff --git a/src/hierarchies/CMakeLists.txt b/src/hierarchies/CMakeLists.txt index 42a45fd87..8426cae04 100644 --- a/src/hierarchies/CMakeLists.txt +++ b/src/hierarchies/CMakeLists.txt @@ -8,6 +8,7 @@ target_sources(bayesmix lin_reg_uni_hierarchy.h fa_hierarchy.h lapnig_hierarchy.h + betagg_hierarchy.h ) add_subdirectory(likelihoods) diff --git a/src/hierarchies/betagg_hierarchy.h b/src/hierarchies/betagg_hierarchy.h new file mode 100644 index 000000000..c2e487679 --- /dev/null +++ b/src/hierarchies/betagg_hierarchy.h @@ -0,0 +1,57 @@ +#ifndef BAYESMIX_HIERARCHIES_BETA_GG_HIERARCHY_H_ +#define BAYESMIX_HIERARCHIES_BETA_GG_HIERARCHY_H_ + +#include "base_hierarchy.h" +#include "hierarchy_id.pb.h" +#include "likelihoods/beta_likelihood.h" +#include "priors/gamma_gamma_prior.h" +#include "updaters/random_walk_updater.h" + +/** + * Beta Gamma-Gamma hierarchy for univaraite data in [0, 1] + * + * This class represents a hierarchical model where data are distributed + * according to a Beta likelihood (see the `BetaLikelihood` class for + * details). The shape and rate parameters of the likelihood have + * independent gamma priors. That is + * + * \f[ + * f(x_i \mid \alpha, \beta) &= Beta(\alpha, \beta) \\ + * \alpha &\sim Gamma(\alpha_a, \alpha_b) \\ + * \beta &\sim Gamma(\beta_a, \beta_b) + * \f] + * + * The state is composed of shape and rate. Note that this hierarchy + * is NOT conjugate, meaning that the marginal distribution is not available + * in closed form. + */ + +class BetaGGHierarchy + : public BaseHierarchy { + public: + BetaGGHierarchy() = default; + ~BetaGGHierarchy() = default; + + //! Returns the Protobuf ID associated to this class + bayesmix::HierarchyId get_id() const override { + return bayesmix::HierarchyId::BetaGG; + } + + //! Sets the default updater algorithm for this hierarchy + void set_default_updater() { + updater = std::make_shared(0.1); + } + + //! Initializes state parameters to appropriate values + void initialize_state() override { + // Get hypers + auto hypers = prior->get_hypers(); + // Initialize likelihood state + State::ShapeRate state; + state.shape = hypers.a_shape / hypers.a_rate; + state.rate = hypers.b_shape / hypers.b_rate; + like->set_state(state); + }; +}; + +#endif // BAYESMIX_HIERARCHIES_BETA_GG_HIERARCHY_H_ diff --git a/src/hierarchies/likelihoods/CMakeLists.txt b/src/hierarchies/likelihoods/CMakeLists.txt index df10e8674..15dea96c0 100644 --- a/src/hierarchies/likelihoods/CMakeLists.txt +++ b/src/hierarchies/likelihoods/CMakeLists.txt @@ -12,6 +12,8 @@ target_sources(bayesmix PUBLIC laplace_likelihood.cc fa_likelihood.h fa_likelihood.cc + beta_likelihood.h + beta_likelihood.cc ) add_subdirectory(states) diff --git a/src/hierarchies/likelihoods/beta_likelihood.cc b/src/hierarchies/likelihoods/beta_likelihood.cc new file mode 100644 index 000000000..ddd08df13 --- /dev/null +++ b/src/hierarchies/likelihoods/beta_likelihood.cc @@ -0,0 +1,22 @@ +#include "beta_likelihood.h" + +double BetaLikelihood::compute_lpdf(const Eigen::RowVectorXd &datum) const { + return stan::math::beta_lpdf(datum(0), state.shape, state.rate); +} + +void BetaLikelihood::update_sum_stats(const Eigen::RowVectorXd &datum, + bool add) { + double x = datum(0); + if (add) { + sum_logs += std::log(x); + sum_logs1m += std::log(1. - x); + } else { + sum_logs -= std::log(x); + sum_logs1m -= std::log(1. - x); + } +} + +void BetaLikelihood::clear_summary_statistics() { + sum_logs = 0; + sum_logs1m = 0; +} diff --git a/src/hierarchies/likelihoods/beta_likelihood.h b/src/hierarchies/likelihoods/beta_likelihood.h new file mode 100644 index 000000000..d99a2aa84 --- /dev/null +++ b/src/hierarchies/likelihoods/beta_likelihood.h @@ -0,0 +1,56 @@ +#ifndef BAYESMIX_HIERARCHIES_LIKELIHOODS_BETA_LIKELIHOOD_H_ +#define BAYESMIX_HIERARCHIES_LIKELIHOODS_BETA_LIKELIHOOD_H_ + +#include + +#include +#include +#include + +#include "algorithm_state.pb.h" +#include "base_likelihood.h" +#include "states/includes.h" + +/** + * A univariate Beta likelihood, using the `State::ShapeRate` state. Represents + * the model: + * + * \f[ + * y_1,\dots,y_k \mid \mu, \sigma^2 \stackrel{\small\mathrm{iid}}{\sim} + * Beta(a, b), + * \f] + */ + +class BetaLikelihood + : public BaseLikelihood { + public: + BetaLikelihood() = default; + ~BetaLikelihood() = default; + bool is_multivariate() const override { return false; }; + bool is_dependent() const override { return false; }; + void clear_summary_statistics() override; + + template + T cluster_lpdf_from_unconstrained( + const Eigen::Matrix &unconstrained_params) const { + assert(unconstrained_params.size() == 2); + + T a = stan::math::positive_constrain(unconstrained_params(0)); + T b = stan::math::positive_constrain(unconstrained_params(1)); + + T out = 0.; + + return (a - 1.) * sum_logs + (b - 1.) * sum_logs1m - + card * (stan::math::lgamma(a) + stan::math::lgamma(b) - + stan::math::lgamma(a + b)); + } + + protected: + double compute_lpdf(const Eigen::RowVectorXd &datum) const override; + void update_sum_stats(const Eigen::RowVectorXd &datum, bool add) override; + + double sum_logs = 0; + double sum_logs1m = 0; +}; + +#endif // BAYESMIX_HIERARCHIES_LIKELIHOODS_BETA_LIKELIHOOD_H_ diff --git a/src/hierarchies/likelihoods/states/includes.h b/src/hierarchies/likelihoods/states/includes.h index f4f868c52..850bc25c8 100644 --- a/src/hierarchies/likelihoods/states/includes.h +++ b/src/hierarchies/likelihoods/states/includes.h @@ -3,6 +3,7 @@ #include "fa_state.h" #include "multi_ls_state.h" +#include "shape_rate_state.h" #include "uni_lin_reg_ls_state.h" #include "uni_ls_state.h" diff --git a/src/hierarchies/likelihoods/states/shape_rate_state.h b/src/hierarchies/likelihoods/states/shape_rate_state.h new file mode 100644 index 000000000..c1ffbcc2a --- /dev/null +++ b/src/hierarchies/likelihoods/states/shape_rate_state.h @@ -0,0 +1,88 @@ +#ifndef BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_SHAPE_RATE_STATE_H_ +#define BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_SHAPE_RATE_STATE_H_ + +#include +#include + +#include "algorithm_state.pb.h" +#include "base_state.h" +#include "src/utils/proto_utils.h" + +namespace State { + +//! Returns the constrained parametrization from the +//! unconstrained one, i.e. [in[0], exp(in[1])] +template +Eigen::Matrix shape_rate_to_constrained( + Eigen::Matrix in) { + Eigen::Matrix out(2); + out << stan::math::exp(in(0)), stan::math::exp(in(1)); + return out; +} + +//! Returns the unconstrained parametrization from the +//! constrained one, i.e. [log(in[0]), log(in[1])] +template +Eigen::Matrix shape_rate_to_unconstrained( + Eigen::Matrix in) { + Eigen::Matrix out(2); + out << stan::math::log(in(0)), stan::math::log(in(1)); + return out; +} + +//! Returns the log determinant of the jacobian of the map +//! (x, y) -> (log(x), log(y)), that is the inverse map of the +//! constrained -> unconstrained representation. +template +T shape_rate_log_det_jac(Eigen::Matrix constrained) { + T out = 0; + stan::math::positive_constrain(stan::math::log(constrained(0)), out); + stan::math::positive_constrain(stan::math::log(constrained(1)), out); + return out; +} + +//! A univariate shape-rate state +//! The unconstrained representation corresponds to (log(shape), log(rate)) +class ShapeRate : public BaseState { + public: + double shape, rate; + + using ProtoState = bayesmix::AlgorithmState::ClusterState; + + Eigen::VectorXd get_unconstrained() const override { + Eigen::VectorXd temp(2); + temp << shape, rate; + return shape_rate_to_unconstrained(temp); + } + + void set_from_unconstrained(const Eigen::VectorXd &in) override { + Eigen::VectorXd temp = shape_rate_to_constrained(in); + shape = temp(0); + rate = temp(1); + } + + void set_from_proto(const ProtoState &state_, bool update_card) override { + if (update_card) { + card = state_.cardinality(); + } + shape = state_.sr_state().shape(); + rate = state_.sr_state().rate(); + } + + ProtoState get_as_proto() const override { + ProtoState state; + state.mutable_sr_state()->set_shape(shape); + state.mutable_sr_state()->set_rate(rate); + return state; + } + + double log_det_jac() const override { + Eigen::VectorXd temp(2); + temp << shape, rate; + return shape_rate_log_det_jac(temp); + } +}; + +} // namespace State + +#endif // BAYESMIX_HIERARCHIES_LIKELIHOODS_STATES_SHAPE_RATE_STATE_H_ diff --git a/src/hierarchies/load_hierarchies.h b/src/hierarchies/load_hierarchies.h index 21ec46ae5..61c8fb18a 100644 --- a/src/hierarchies/load_hierarchies.h +++ b/src/hierarchies/load_hierarchies.h @@ -5,6 +5,7 @@ #include #include "abstract_hierarchy.h" +#include "betagg_hierarchy.h" #include "fa_hierarchy.h" #include "hierarchy_id.pb.h" #include "lapnig_hierarchy.h" @@ -43,6 +44,9 @@ __attribute__((constructor)) static void load_hierarchies() { Builder LapNIGbuilder = []() { return std::make_shared(); }; + Builder BetaGGbuilder = []() { + return std::make_shared(); + }; factory.add_builder(NNIGHierarchy().get_id(), NNIGbuilder); factory.add_builder(NNxIGHierarchy().get_id(), NNxIGbuilder); @@ -50,6 +54,7 @@ __attribute__((constructor)) static void load_hierarchies() { factory.add_builder(LinRegUniHierarchy().get_id(), LinRegUnibuilder); factory.add_builder(FAHierarchy().get_id(), FAbuilder); factory.add_builder(LapNIGHierarchy().get_id(), LapNIGbuilder); + factory.add_builder(BetaGGHierarchy().get_id(), BetaGGbuilder); } #endif // BAYESMIX_HIERARCHIES_LOAD_HIERARCHIES_H_ diff --git a/src/hierarchies/priors/CMakeLists.txt b/src/hierarchies/priors/CMakeLists.txt index d6901ee65..53ad75518 100644 --- a/src/hierarchies/priors/CMakeLists.txt +++ b/src/hierarchies/priors/CMakeLists.txt @@ -13,4 +13,6 @@ target_sources(bayesmix PUBLIC mnig_prior_model.cc fa_prior_model.h fa_prior_model.cc + gamma_gamma_prior.h + gamma_gamma_prior.cc ) diff --git a/src/hierarchies/priors/gamma_gamma_prior.cc b/src/hierarchies/priors/gamma_gamma_prior.cc new file mode 100644 index 000000000..9e6c58d20 --- /dev/null +++ b/src/hierarchies/priors/gamma_gamma_prior.cc @@ -0,0 +1,67 @@ +#include "gamma_gamma_prior.h" + +double GGPriorModel::lpdf(const google::protobuf::Message &state_) { + // Downcast state + auto &state = downcast_state(state_).sr_state(); + double target = 0.; + target += + stan::math::gamma_lpdf(state.shape(), hypers->a_shape, hypers->a_rate); + target += + stan::math::gamma_lpdf(state.rate(), hypers->b_shape, hypers->b_rate); + return target; +} + +State::ShapeRate GGPriorModel::sample(ProtoHypersPtr hier_hypers) { + // Random seed + auto &rng = bayesmix::Rng::Instance().get(); + + // Get params to use + auto params = get_hypers_proto()->gg_state(); + State::ShapeRate out; + out.shape = stan::math::gamma_rng(params.a_shape(), params.a_rate(), rng); + out.rate = stan::math::gamma_rng(params.b_shape(), params.b_rate(), rng); + return out; +} + +void GGPriorModel::update_hypers( + const std::vector &states) { + auto &rng = bayesmix::Rng::Instance().get(); + if (prior->has_fixed_values()) { + return; + } else { + throw std::invalid_argument("Unrecognized hierarchy prior"); + } +} + +void GGPriorModel::set_hypers_from_proto( + const google::protobuf::Message &hypers_) { + auto &hyperscast = downcast_hypers(hypers_).gg_state(); + hypers->a_shape = hyperscast.a_shape(); + hypers->a_rate = hyperscast.a_rate(); + hypers->b_shape = hyperscast.b_shape(); + hypers->b_rate = hyperscast.b_rate(); +} + +std::shared_ptr +GGPriorModel::get_hypers_proto() const { + bayesmix::GamGamDistribution hypers_; + hypers_.set_a_shape(hypers->a_shape); + hypers_.set_a_rate(hypers->a_rate); + hypers_.set_b_shape(hypers->b_shape); + hypers_.set_b_rate(hypers->b_rate); + auto out = std::make_shared(); + out->mutable_gg_state()->CopyFrom(hypers_); + return out; +} + +void GGPriorModel::initialize_hypers() { + if (prior->has_fixed_values()) { + // Set values + hypers->a_shape = prior->fixed_values().a_shape(); + hypers->a_rate = prior->fixed_values().a_rate(); + hypers->b_shape = prior->fixed_values().b_shape(); + hypers->b_rate = prior->fixed_values().b_rate(); + } else { + throw std::invalid_argument("Unrecognized hierarchy prior"); + } +} diff --git a/src/hierarchies/priors/gamma_gamma_prior.h b/src/hierarchies/priors/gamma_gamma_prior.h new file mode 100644 index 000000000..9ce995771 --- /dev/null +++ b/src/hierarchies/priors/gamma_gamma_prior.h @@ -0,0 +1,64 @@ +#ifndef BAYESMIX_HIERARCHIES_PRIORS_GG_PRIOR_MODEL_H_ +#define BAYESMIX_HIERARCHIES_PRIORS_GG_PRIOR_MODEL_H_ + +#include +#include +#include + +#include "base_prior_model.h" +#include "hierarchy_prior.pb.h" +#include "hyperparams.h" +#include "src/utils/rng.h" + +/* + * Prior model for `ShapeRate` states. + * This class assumes that the shape and rate are independent and given + * Gamma-distributed priors + */ +class GGPriorModel + : public BasePriorModel { + public: + using AbstractPriorModel::ProtoHypers; + using AbstractPriorModel::ProtoHypersPtr; + + GGPriorModel() = default; + ~GGPriorModel() = default; + + double lpdf(const google::protobuf::Message &state_) override; + + template + T lpdf_from_unconstrained( + const Eigen::Matrix &unconstrained_params) const { + Eigen::Matrix constrained_params = + State::shape_rate_to_constrained(unconstrained_params); + // std::cout << "constrained_params: " << constrained_params << std::endl; + T log_det_jac = State::shape_rate_log_det_jac(constrained_params); + T shape = constrained_params(0); + T rate = constrained_params(1); + T lpdf = stan::math::gamma_lpdf(shape, hypers->a_shape, hypers->a_rate) + + stan::math::gamma_lpdf(rate, hypers->a_shape, hypers->a_rate); + + return lpdf + log_det_jac; + } + + State::ShapeRate sample(ProtoHypersPtr hier_hypers = nullptr) override; + + void update_hypers(const std::vector + &states) override; + + void set_hypers_from_proto( + const google::protobuf::Message &hypers_) override; + + unsigned int get_dim() const { return dim; }; + + std::shared_ptr get_hypers_proto() + const override; + + protected: + void initialize_hypers() override; + + unsigned int dim; +}; + +#endif // BAYESMIX_HIERARCHIES_PRIORS_GG_PRIOR_MODEL_H_ diff --git a/src/hierarchies/priors/hyperparams.h b/src/hierarchies/priors/hyperparams.h index 1aca6dc4a..ccc6270c2 100644 --- a/src/hierarchies/priors/hyperparams.h +++ b/src/hierarchies/priors/hyperparams.h @@ -31,6 +31,11 @@ struct FA { unsigned int q; }; +struct GG { + double a_shape, a_rate; + double b_shape, b_rate; +}; + } // namespace Hyperparams #endif // BAYESMIX_HIERARCHIES_PRIORS_HYPERPARAMS_H_ diff --git a/src/proto/algorithm_state.proto b/src/proto/algorithm_state.proto index 02f6ba07d..ad8b4e641 100644 --- a/src/proto/algorithm_state.proto +++ b/src/proto/algorithm_state.proto @@ -26,7 +26,7 @@ message AlgorithmState { LinRegUniLSState lin_reg_uni_ls_state = 4; // State of a linear regression univariate location-scale family Vector general_state = 5; // Just a vector of doubles FAState fa_state = 6; // State of a Mixture of Factor Analysers - + SRState sr_state = 7; // State of a Mixture of Beta distributions } int32 cardinality = 3; // How many observations are in this cluster } @@ -45,6 +45,7 @@ message AlgorithmState { MultiNormalIGDistribution lin_reg_uni_state = 4; NxIGDistribution nnxig_state = 5; FAPriorDistribution fa_state = 7; + GamGamDistribution gg_state = 8; } } HierarchyHypers hierarchy_hypers = 5; // The current values of the hyperparameters of the hierarchy diff --git a/src/proto/distribution.proto b/src/proto/distribution.proto index 804813f78..785e2f3ae 100644 --- a/src/proto/distribution.proto +++ b/src/proto/distribution.proto @@ -95,3 +95,14 @@ message MultiNormalIGDistribution { double shape = 3; double scale = 4; } + +/* + * Parameters for the product of two independent Gamma distributions with different + * rates and shapes + */ +message GamGamDistribution { + double a_shape = 1; + double a_rate = 2; + double b_shape = 3; + double b_rate = 4; +} diff --git a/src/proto/hierarchy_id.proto b/src/proto/hierarchy_id.proto index 45e9b4d6d..710c2e075 100644 --- a/src/proto/hierarchy_id.proto +++ b/src/proto/hierarchy_id.proto @@ -14,4 +14,5 @@ enum HierarchyId { FA = 5; // Factor Analysers NNxIG = 6; // Normal - Normal x Inverse Gamma PythonHier = 7; // Generic python hierarchy + BetaGG = 8; // Beta - GammaGamma hierarchy } diff --git a/src/proto/hierarchy_prior.proto b/src/proto/hierarchy_prior.proto index 09558e0d5..7722e29c3 100644 --- a/src/proto/hierarchy_prior.proto +++ b/src/proto/hierarchy_prior.proto @@ -123,3 +123,12 @@ message PythonHierPrior{ Vector values = 1; // values are modified from python } } + +/* + * Prior for the parameters of the base measure in a Beta - GammaGamma hierarchy + */ +message GGPrior { + oneof prior { + GamGamDistribution fixed_values = 1; + } +} diff --git a/src/proto/ls_state.proto b/src/proto/ls_state.proto index 50028afde..6794185c5 100644 --- a/src/proto/ls_state.proto +++ b/src/proto/ls_state.proto @@ -37,3 +37,11 @@ message FAState { Matrix eta = 3; Matrix lambda = 4; } + +/* + * Parameters of a shape-rate state, used, e.g., by the BetaLikelihood + */ +message SRState { + double shape = 1; + double rate = 2; +} diff --git a/test/hierarchies.cc b/test/hierarchies.cc index 0a4a94c25..2c62c1a22 100644 --- a/test/hierarchies.cc +++ b/test/hierarchies.cc @@ -5,6 +5,7 @@ #include "algorithm_state.pb.h" #include "ls_state.pb.h" +#include "src/hierarchies/betagg_hierarchy.h" #include "src/hierarchies/fa_hierarchy.h" #include "src/hierarchies/lin_reg_uni_hierarchy.h" #include "src/hierarchies/nnig_hierarchy.h" @@ -373,3 +374,44 @@ TEST(fa_hierarchy, sample_given_data) { ASSERT_TRUE(clusval->DebugString() != clusval2->DebugString()); } + +TEST(betagg_hierarchy, sample_given_data) { + auto hier = std::make_shared(); + bayesmix::GGPrior prior; + + prior.mutable_fixed_values()->set_a_rate(2.0); + prior.mutable_fixed_values()->set_a_shape(2.0); + prior.mutable_fixed_values()->set_b_rate(2.0); + prior.mutable_fixed_values()->set_b_shape(2.0); + hier->get_mutable_prior()->CopyFrom(prior); + + auto mala_updater = std::make_shared(0.0005); + hier->set_updater(mala_updater); + + Eigen::MatrixXd dataset = Eigen::MatrixXd::Ones(100, 1) * 0.1; + hier->set_dataset(&dataset); + hier->initialize(); + + for (int i = 0; i < dataset.rows(); i++) { + hier->add_datum(i, dataset.row(i)); + } + + for (int i = 0; i < 1000; i++) { + hier->sample_full_cond(); + } + + bayesmix::AlgorithmState::ClusterState out; + hier->write_state_to_proto(&out); + + double a = out.sr_state().shape(); + double b = out.sr_state().rate(); + + double mean = a / (a + b); + double var = a * b / ((a + b) * (a + b) * (a + b + 1)); + + ASSERT_LT(mean, 0.2); + ASSERT_GT(mean, 0.05); + ASSERT_LT(var, 0.1); + + // ASSERT_TRUE(clusval->DebugString() != clusval2->DebugString()); +} diff --git a/test/likelihoods.cc b/test/likelihoods.cc index 5223f4f6d..d7e116760 100644 --- a/test/likelihoods.cc +++ b/test/likelihoods.cc @@ -6,6 +6,7 @@ #include "algorithm_state.pb.h" #include "ls_state.pb.h" +#include "src/hierarchies/likelihoods/beta_likelihood.h" #include "src/hierarchies/likelihoods/laplace_likelihood.h" #include "src/hierarchies/likelihoods/multi_norm_likelihood.h" #include "src/hierarchies/likelihoods/uni_lin_reg_likelihood.h" @@ -389,3 +390,84 @@ TEST(laplace_likelihood, eval_lpdf_unconstrained) { clus_lpdf = like->cluster_lpdf_from_unconstrained(unconstrained_params); ASSERT_TRUE(std::abs(clus_lpdf - lpdf) > 1e-5); } + +TEST(beta_likelihood, set_get_state) { + // Instance + auto like = std::make_shared(); + + // Prepare buffers + bayesmix::ShapeRateState state_; + bayesmix::AlgorithmState::ClusterState set_state_; + bayesmix::AlgorithmState::ClusterState got_state_; + + // Prepare state + state_.set_shape(5.23); + state_.set_rate(1.02); + set_state_.mutable_sr_state()->CopyFrom(state_); + + // Set and get the state + like->set_state_from_proto(set_state_); + like->write_state_to_proto(&got_state_); + + // Check if they coincides + ASSERT_EQ(got_state_.DebugString(), set_state_.DebugString()); +} + +TEST(beta_likelihood, eval_lpdf) { + // Instance + auto like = std::make_shared(); + + // Set state from proto + bayesmix::ShapeRateState state_; + bayesmix::AlgorithmState::ClusterState clust_state_; + state_.set_shape(5); + state_.set_rate(1); + clust_state_.mutable_sr_state()->CopyFrom(state_); + like->set_state_from_proto(clust_state_); + + // Add new datum to likelihood + Eigen::VectorXd data(3); + data << 0.5, 0.6, 0.4; + + // Compute lpdf on this grid of points + auto evals = like->lpdf_grid(data); + auto like_copy = like->clone(); + auto evals_copy = like_copy->lpdf_grid(data); + + // Check if they coincides + ASSERT_EQ(evals, evals_copy); +} + +TEST(beta_likelihood, eval_lpdf_unconstrained) { + // Instance + auto like = std::make_shared(); + + // Set state from proto + bayesmix::ShapeRateState state_; + bayesmix::AlgorithmState::ClusterState clust_state_; + double a = 5; + double b = 1; + state_.set_shape(a); + state_.set_rate(b); + Eigen::VectorXd unconstrained_params(2); + unconstrained_params << std::log(a), std::log(b); + clust_state_.mutable_sr_state()->CopyFrom(state_); + like->set_state_from_proto(clust_state_); + + // Add new datum to likelihood + Eigen::VectorXd data(3); + data << 0.5, 0.6, 0.4; + double lpdf = 0.0; + for (int i = 0; i < data.size(); ++i) { + like->add_datum(i, data.row(i)); + lpdf += like->lpdf(data.row(i)); + } + + double clus_lpdf = + like->cluster_lpdf_from_unconstrained(unconstrained_params); + ASSERT_NEAR(lpdf, clus_lpdf, 1e-12); + + unconstrained_params(0) = 4.0; + clus_lpdf = like->cluster_lpdf_from_unconstrained(unconstrained_params); + ASSERT_TRUE(std::abs(clus_lpdf - lpdf) > 1e-5); +} diff --git a/test/prior_models.cc b/test/prior_models.cc index fcf380c4a..f0c624370 100644 --- a/test/prior_models.cc +++ b/test/prior_models.cc @@ -6,6 +6,7 @@ #include "algorithm_state.pb.h" #include "hierarchy_prior.pb.h" +#include "src/hierarchies/priors/gamma_gamma_prior.h" #include "src/hierarchies/priors/mnig_prior_model.h" #include "src/hierarchies/priors/nig_prior_model.h" #include "src/hierarchies/priors/nw_prior_model.h" @@ -487,3 +488,85 @@ TEST(mnig_prior_model, sample) { ASSERT_TRUE(state1.get_as_proto().DebugString() != state2.get_as_proto().DebugString()); } + +TEST(gg_prior_model, set_get_hypers) { + // Instance + auto prior = std::make_shared(); + + // Prepare buffers + bayesmix::GamGamDistribution hypers_; + bayesmix::AlgorithmState::HierarchyHypers set_state_; + bayesmix::AlgorithmState::HierarchyHypers got_state_; + + // Prepare hypers + hypers_.set_a_shape(5.0); + hypers_.set_a_rate(1.0); + hypers_.set_b_shape(4.0); + hypers_.set_b_rate(3.0); + set_state_.mutable_gg_state()->CopyFrom(hypers_); + + // Set and get hypers + prior->set_hypers_from_proto(set_state_); + prior->write_hypers_to_proto(&got_state_); + + // Check if they coincides + ASSERT_EQ(got_state_.DebugString(), set_state_.DebugString()); +} + +TEST(gg_prior_model, fixed_values_prior) { + // Prepare buffers + bayesmix::GGPrior prior; + bayesmix::AlgorithmState::HierarchyHypers prior_out; + std::vector> prior_models; + std::vector states; + + // Set fixed value prior + prior.mutable_fixed_values()->set_a_shape(5.0); + prior.mutable_fixed_values()->set_a_rate(1.0); + prior.mutable_fixed_values()->set_b_shape(4.0); + prior.mutable_fixed_values()->set_b_rate(3.0); + + // Initialize prior model + auto prior_model = std::make_shared(); + prior_model->get_mutable_prior()->CopyFrom(prior); + prior_model->initialize(); + + // Check equality before update + prior_models.push_back(prior_model); + for (size_t i = 1; i < 4; i++) { + prior_models.push_back(prior_model->clone()); + prior_models[i]->write_hypers_to_proto(&prior_out); + ASSERT_EQ(prior.fixed_values().DebugString(), + prior_out.gg_state().DebugString()); + } + + // Check equality after update + prior_models[0]->update_hypers(states); + prior_models[0]->write_hypers_to_proto(&prior_out); + for (size_t i = 1; i < 4; i++) { + prior_models[i]->write_hypers_to_proto(&prior_out); + ASSERT_EQ(prior.fixed_values().DebugString(), + prior_out.gg_state().DebugString()); + } +} + +TEST(gg_prior_model, sample) { + // Instance + auto prior = std::make_shared(); + + // Define prior hypers + bayesmix::AlgorithmState::HierarchyHypers hypers_proto; + hypers_proto.mutable_gg_state()->set_a_shape(5.0); + hypers_proto.mutable_gg_state()->set_a_rate(1.0); + hypers_proto.mutable_gg_state()->set_b_shape(4.0); + hypers_proto.mutable_gg_state()->set_b_rate(3.0); + + // Set hypers and get sampled state as proto + prior->set_hypers_from_proto(hypers_proto); + auto state1 = prior->sample(); + auto state2 = prior->sample(); + + // Check if they coincides + ASSERT_TRUE(state1.get_as_proto().DebugString() != + state2.get_as_proto().DebugString()); +}