Skip to content

Commit

Permalink
Implements getter functions for LFMCMC (#49)
Browse files Browse the repository at this point in the history
* Fix sizing bug in set_summary_fun_cpp that causes res to be too large when the summary fun output was smaller than dat

* Add get_params_mean_cpp and get_stats_mean_cpp to lfmcmc.cpp

* Add get_params_mean and get_stats_mean in LFMCMC.R

* Add tests for get_stats_mean and get_params_mean
  • Loading branch information
apulsipher authored Nov 19, 2024
1 parent bf4b679 commit 21d89c7
Show file tree
Hide file tree
Showing 8 changed files with 92 additions and 4 deletions.
4 changes: 4 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,10 @@ S3method(get_n_viruses,epiworld_model)
S3method(get_name,epiworld_model)
S3method(get_ndays,epiworld_model)
S3method(get_param,epiworld_model)
S3method(get_params_mean,epiworld_lfmcmc)
S3method(get_reproductive_number,epiworld_model)
S3method(get_states,epiworld_model)
S3method(get_stats_mean,epiworld_lfmcmc)
S3method(get_today_total,epiworld_model)
S3method(get_transition_probability,epiworld_model)
S3method(get_transmissions,epiworld_diffnet)
Expand Down Expand Up @@ -150,9 +152,11 @@ export(get_name_virus)
export(get_ndays)
export(get_network)
export(get_param)
export(get_params_mean)
export(get_reproductive_number)
export(get_state)
export(get_states)
export(get_stats_mean)
export(get_today_total)
export(get_tool)
export(get_transition_probability)
Expand Down
26 changes: 26 additions & 0 deletions R/LFMCMC.R
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,10 @@
#' set_par_names(lfmcmc_model, c("Immune recovery", "Infectiousness"))
#'
#' print(lfmcmc_model)
#'
#' get_stats_mean(lfmcmc_model)
#' get_params_mean(lfmcmc_model)
#'
#' @export
LFMCMC <- function(model) {
if (!inherits(model, "epiworld_model"))
Expand Down Expand Up @@ -205,6 +209,28 @@ set_stats_names.epiworld_lfmcmc <- function(lfmcmc, names) {
invisible(lfmcmc)
}

#' @rdname LFMCMC
#' @param lfmcmc LFMCMC model
#' @returns The param means for the given lfmcmc model
#' @export
get_params_mean <- function(lfmcmc) UseMethod("get_params_mean")

#' @export
get_params_mean.epiworld_lfmcmc <- function(lfmcmc) {
get_params_mean_cpp(lfmcmc)
}

#' @rdname LFMCMC
#' @param lfmcmc LFMCMC model
#' @returns The stats means for the given lfmcmc model
#' @export
get_stats_mean <- function(lfmcmc) UseMethod("get_stats_mean")

#' @export
get_stats_mean.epiworld_lfmcmc <- function(lfmcmc) {
get_stats_mean_cpp(lfmcmc)
}

#' @rdname LFMCMC
#' @param x LFMCMC model to print
#' @param ... Ignored
Expand Down
8 changes: 8 additions & 0 deletions R/cpp11.R
Original file line number Diff line number Diff line change
Expand Up @@ -264,6 +264,14 @@ set_stats_names_cpp <- function(lfmcmc, names) {
.Call(`_epiworldR_set_stats_names_cpp`, lfmcmc, names)
}

get_params_mean_cpp <- function(lfmcmc) {
.Call(`_epiworldR_get_params_mean_cpp`, lfmcmc)
}

get_stats_mean_cpp <- function(lfmcmc) {
.Call(`_epiworldR_get_stats_mean_cpp`, lfmcmc)
}

print_lfmcmc_cpp <- function(lfmcmc) {
.Call(`_epiworldR_print_lfmcmc_cpp`, lfmcmc)
}
Expand Down
3 changes: 3 additions & 0 deletions inst/tinytest/test-lfmcmc.R
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,9 @@ expect_silent(set_par_names(lfmcmc_model, c("Immune recovery", "Infectiousness")

expect_stdout(print(lfmcmc_model))

expect_equal(get_stats_mean(lfmcmc_model), c(4.45, 2.6135, 992.4365))
expect_equal(get_params_mean(lfmcmc_model), c(11.58421, 18.96851), tolerance = 0.00001)

# Check LFMCMC using factory functions -----------------------------------------
expect_silent(use_proposal_norm_reflective(lfmcmc_model))
expect_silent(use_kernel_fun_gaussian(lfmcmc_model))
Expand Down
14 changes: 14 additions & 0 deletions man/LFMCMC.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/epiworld-methods.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 16 additions & 0 deletions src/cpp11.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -468,6 +468,20 @@ extern "C" SEXP _epiworldR_set_stats_names_cpp(SEXP lfmcmc, SEXP names) {
END_CPP11
}
// lfmcmc.cpp
cpp11::writable::doubles get_params_mean_cpp(SEXP lfmcmc);
extern "C" SEXP _epiworldR_get_params_mean_cpp(SEXP lfmcmc) {
BEGIN_CPP11
return cpp11::as_sexp(get_params_mean_cpp(cpp11::as_cpp<cpp11::decay_t<SEXP>>(lfmcmc)));
END_CPP11
}
// lfmcmc.cpp
cpp11::writable::doubles get_stats_mean_cpp(SEXP lfmcmc);
extern "C" SEXP _epiworldR_get_stats_mean_cpp(SEXP lfmcmc) {
BEGIN_CPP11
return cpp11::as_sexp(get_stats_mean_cpp(cpp11::as_cpp<cpp11::decay_t<SEXP>>(lfmcmc)));
END_CPP11
}
// lfmcmc.cpp
SEXP print_lfmcmc_cpp(SEXP lfmcmc);
extern "C" SEXP _epiworldR_print_lfmcmc_cpp(SEXP lfmcmc) {
BEGIN_CPP11
Expand Down Expand Up @@ -1068,9 +1082,11 @@ static const R_CallMethodDef CallEntries[] = {
{"_epiworldR_get_ndays_cpp", (DL_FUNC) &_epiworldR_get_ndays_cpp, 1},
{"_epiworldR_get_network_cpp", (DL_FUNC) &_epiworldR_get_network_cpp, 1},
{"_epiworldR_get_param_cpp", (DL_FUNC) &_epiworldR_get_param_cpp, 2},
{"_epiworldR_get_params_mean_cpp", (DL_FUNC) &_epiworldR_get_params_mean_cpp, 1},
{"_epiworldR_get_reproductive_number_cpp", (DL_FUNC) &_epiworldR_get_reproductive_number_cpp, 1},
{"_epiworldR_get_state_agent_cpp", (DL_FUNC) &_epiworldR_get_state_agent_cpp, 1},
{"_epiworldR_get_states_cpp", (DL_FUNC) &_epiworldR_get_states_cpp, 1},
{"_epiworldR_get_stats_mean_cpp", (DL_FUNC) &_epiworldR_get_stats_mean_cpp, 1},
{"_epiworldR_get_today_total_cpp", (DL_FUNC) &_epiworldR_get_today_total_cpp, 1},
{"_epiworldR_get_tool_model_cpp", (DL_FUNC) &_epiworldR_get_tool_model_cpp, 2},
{"_epiworldR_get_transition_probability_cpp", (DL_FUNC) &_epiworldR_get_transition_probability_cpp, 1},
Expand Down
23 changes: 20 additions & 3 deletions src/lfmcmc.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#include "cpp11/external_pointer.hpp"
#include "cpp11/r_vector.hpp"
#include "cpp11/sexp.hpp"
#include "cpp11/doubles.hpp"
#include <iostream>

#include "epiworld-common.h"
Expand Down Expand Up @@ -145,12 +146,12 @@ SEXP set_summary_fun_cpp(
LFMCMC<TData_default>*
) -> void {

if (res.size() == 0u)
res.resize(dat.size());

auto dat_int = cpp11::integers(dat);
auto res_tmp = cpp11::integers(fun(dat_int));

if (res.size() == 0u)
res.resize(res_tmp.size());

std::copy(res_tmp.begin(), res_tmp.end(), res.begin());

return;
Expand Down Expand Up @@ -225,6 +226,22 @@ SEXP set_stats_names_cpp(
return lfmcmc;
}

[[cpp11::register]]
cpp11::writable::doubles get_params_mean_cpp(
SEXP lfmcmc
) {
WrapLFMCMC(lfmcmc_ptr)(lfmcmc);
return cpp11::doubles(lfmcmc_ptr->get_params_mean());
}

[[cpp11::register]]
cpp11::writable::doubles get_stats_mean_cpp(
SEXP lfmcmc
) {
WrapLFMCMC(lfmcmc_ptr)(lfmcmc);
return cpp11::doubles(lfmcmc_ptr->get_stats_mean());
}

[[cpp11::register]]
SEXP print_lfmcmc_cpp(
SEXP lfmcmc
Expand Down

0 comments on commit 21d89c7

Please sign in to comment.