diff --git a/NAMESPACE b/NAMESPACE index cd90337..82bf030 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -27,8 +27,10 @@ S3method(get_n_viruses,epiworld_model) S3method(get_name,epiworld_model) S3method(get_ndays,epiworld_model) S3method(get_param,epiworld_model) +S3method(get_params_mean,epiworld_lfmcmc) S3method(get_reproductive_number,epiworld_model) S3method(get_states,epiworld_model) +S3method(get_stats_mean,epiworld_lfmcmc) S3method(get_today_total,epiworld_model) S3method(get_transition_probability,epiworld_model) S3method(get_transmissions,epiworld_diffnet) @@ -150,9 +152,11 @@ export(get_name_virus) export(get_ndays) export(get_network) export(get_param) +export(get_params_mean) export(get_reproductive_number) export(get_state) export(get_states) +export(get_stats_mean) export(get_today_total) export(get_tool) export(get_transition_probability) diff --git a/R/LFMCMC.R b/R/LFMCMC.R index efcc8a8..b18ba81 100644 --- a/R/LFMCMC.R +++ b/R/LFMCMC.R @@ -68,6 +68,10 @@ #' set_par_names(lfmcmc_model, c("Immune recovery", "Infectiousness")) #' #' print(lfmcmc_model) +#' +#' get_stats_mean(lfmcmc_model) +#' get_params_mean(lfmcmc_model) +#' #' @export LFMCMC <- function(model) { if (!inherits(model, "epiworld_model")) @@ -205,6 +209,28 @@ set_stats_names.epiworld_lfmcmc <- function(lfmcmc, names) { invisible(lfmcmc) } +#' @rdname LFMCMC +#' @param lfmcmc LFMCMC model +#' @returns The param means for the given lfmcmc model +#' @export +get_params_mean <- function(lfmcmc) UseMethod("get_params_mean") + +#' @export +get_params_mean.epiworld_lfmcmc <- function(lfmcmc) { + get_params_mean_cpp(lfmcmc) +} + +#' @rdname LFMCMC +#' @param lfmcmc LFMCMC model +#' @returns The stats means for the given lfmcmc model +#' @export +get_stats_mean <- function(lfmcmc) UseMethod("get_stats_mean") + +#' @export +get_stats_mean.epiworld_lfmcmc <- function(lfmcmc) { + get_stats_mean_cpp(lfmcmc) +} + #' @rdname LFMCMC #' @param x LFMCMC model to print #' @param ... Ignored diff --git a/R/cpp11.R b/R/cpp11.R index 71c5468..7b58bc8 100644 --- a/R/cpp11.R +++ b/R/cpp11.R @@ -264,6 +264,14 @@ set_stats_names_cpp <- function(lfmcmc, names) { .Call(`_epiworldR_set_stats_names_cpp`, lfmcmc, names) } +get_params_mean_cpp <- function(lfmcmc) { + .Call(`_epiworldR_get_params_mean_cpp`, lfmcmc) +} + +get_stats_mean_cpp <- function(lfmcmc) { + .Call(`_epiworldR_get_stats_mean_cpp`, lfmcmc) +} + print_lfmcmc_cpp <- function(lfmcmc) { .Call(`_epiworldR_print_lfmcmc_cpp`, lfmcmc) } diff --git a/inst/tinytest/test-lfmcmc.R b/inst/tinytest/test-lfmcmc.R index aa89b48..4c47628 100644 --- a/inst/tinytest/test-lfmcmc.R +++ b/inst/tinytest/test-lfmcmc.R @@ -72,6 +72,9 @@ expect_silent(set_par_names(lfmcmc_model, c("Immune recovery", "Infectiousness") expect_stdout(print(lfmcmc_model)) +expect_equal(get_stats_mean(lfmcmc_model), c(4.45, 2.6135, 992.4365)) +expect_equal(get_params_mean(lfmcmc_model), c(11.58421, 18.96851), tolerance = 0.00001) + # Check LFMCMC using factory functions ----------------------------------------- expect_silent(use_proposal_norm_reflective(lfmcmc_model)) expect_silent(use_kernel_fun_gaussian(lfmcmc_model)) diff --git a/man/LFMCMC.Rd b/man/LFMCMC.Rd index 2bf0434..7e9204d 100644 --- a/man/LFMCMC.Rd +++ b/man/LFMCMC.Rd @@ -13,6 +13,8 @@ \alias{use_kernel_fun_gaussian} \alias{set_par_names} \alias{set_stats_names} +\alias{get_params_mean} +\alias{get_stats_mean} \alias{print.epiworld_lfmcmc} \title{Likelihood-Free Markhov Chain Monte Carlo (LFMCMC)} \usage{ @@ -38,6 +40,10 @@ set_par_names(lfmcmc, names) set_stats_names(lfmcmc, names) +get_params_mean(lfmcmc) + +get_stats_mean(lfmcmc) + \method{print}{epiworld_lfmcmc}(x, ...) } \arguments{ @@ -86,6 +92,10 @@ The lfmcmc model with the parameter names added The lfmcmc model with the stats names added +The param means for the given lfmcmc model + +The stats means for the given lfmcmc model + The lfmcmc model } \description{ @@ -155,4 +165,8 @@ set_stats_names(lfmcmc_model, get_states(model_sir)) set_par_names(lfmcmc_model, c("Immune recovery", "Infectiousness")) print(lfmcmc_model) + +get_stats_mean(lfmcmc_model) +get_params_mean(lfmcmc_model) + } diff --git a/man/epiworld-methods.Rd b/man/epiworld-methods.Rd index 049b428..f7357fd 100644 --- a/man/epiworld-methods.Rd +++ b/man/epiworld-methods.Rd @@ -271,7 +271,7 @@ get_n_tools(model_sirconn) # Returns the number of tools in the model. In get_ndays(model_sirconn) # Returns the length of the simulation in days. This # will match "ndays" within the "run" function. -today(model_sirconn) # Returns the current day of the simulation. This will +today(model_sirconn) # Returns the current day of the simulation. This will # match "get_ndays()" if run at the end of a simulation, but will differ if run # during a simulation diff --git a/src/cpp11.cpp b/src/cpp11.cpp index 3cfb062..4c6addc 100644 --- a/src/cpp11.cpp +++ b/src/cpp11.cpp @@ -468,6 +468,20 @@ extern "C" SEXP _epiworldR_set_stats_names_cpp(SEXP lfmcmc, SEXP names) { END_CPP11 } // lfmcmc.cpp +cpp11::writable::doubles get_params_mean_cpp(SEXP lfmcmc); +extern "C" SEXP _epiworldR_get_params_mean_cpp(SEXP lfmcmc) { + BEGIN_CPP11 + return cpp11::as_sexp(get_params_mean_cpp(cpp11::as_cpp>(lfmcmc))); + END_CPP11 +} +// lfmcmc.cpp +cpp11::writable::doubles get_stats_mean_cpp(SEXP lfmcmc); +extern "C" SEXP _epiworldR_get_stats_mean_cpp(SEXP lfmcmc) { + BEGIN_CPP11 + return cpp11::as_sexp(get_stats_mean_cpp(cpp11::as_cpp>(lfmcmc))); + END_CPP11 +} +// lfmcmc.cpp SEXP print_lfmcmc_cpp(SEXP lfmcmc); extern "C" SEXP _epiworldR_print_lfmcmc_cpp(SEXP lfmcmc) { BEGIN_CPP11 @@ -1068,9 +1082,11 @@ static const R_CallMethodDef CallEntries[] = { {"_epiworldR_get_ndays_cpp", (DL_FUNC) &_epiworldR_get_ndays_cpp, 1}, {"_epiworldR_get_network_cpp", (DL_FUNC) &_epiworldR_get_network_cpp, 1}, {"_epiworldR_get_param_cpp", (DL_FUNC) &_epiworldR_get_param_cpp, 2}, + {"_epiworldR_get_params_mean_cpp", (DL_FUNC) &_epiworldR_get_params_mean_cpp, 1}, {"_epiworldR_get_reproductive_number_cpp", (DL_FUNC) &_epiworldR_get_reproductive_number_cpp, 1}, {"_epiworldR_get_state_agent_cpp", (DL_FUNC) &_epiworldR_get_state_agent_cpp, 1}, {"_epiworldR_get_states_cpp", (DL_FUNC) &_epiworldR_get_states_cpp, 1}, + {"_epiworldR_get_stats_mean_cpp", (DL_FUNC) &_epiworldR_get_stats_mean_cpp, 1}, {"_epiworldR_get_today_total_cpp", (DL_FUNC) &_epiworldR_get_today_total_cpp, 1}, {"_epiworldR_get_tool_model_cpp", (DL_FUNC) &_epiworldR_get_tool_model_cpp, 2}, {"_epiworldR_get_transition_probability_cpp", (DL_FUNC) &_epiworldR_get_transition_probability_cpp, 1}, diff --git a/src/lfmcmc.cpp b/src/lfmcmc.cpp index 075d401..5b615c5 100644 --- a/src/lfmcmc.cpp +++ b/src/lfmcmc.cpp @@ -2,6 +2,7 @@ #include "cpp11/external_pointer.hpp" #include "cpp11/r_vector.hpp" #include "cpp11/sexp.hpp" +#include "cpp11/doubles.hpp" #include #include "epiworld-common.h" @@ -145,12 +146,12 @@ SEXP set_summary_fun_cpp( LFMCMC* ) -> void { - if (res.size() == 0u) - res.resize(dat.size()); - auto dat_int = cpp11::integers(dat); auto res_tmp = cpp11::integers(fun(dat_int)); + if (res.size() == 0u) + res.resize(res_tmp.size()); + std::copy(res_tmp.begin(), res_tmp.end(), res.begin()); return; @@ -225,6 +226,22 @@ SEXP set_stats_names_cpp( return lfmcmc; } +[[cpp11::register]] +cpp11::writable::doubles get_params_mean_cpp( + SEXP lfmcmc +) { + WrapLFMCMC(lfmcmc_ptr)(lfmcmc); + return cpp11::doubles(lfmcmc_ptr->get_params_mean()); +} + +[[cpp11::register]] +cpp11::writable::doubles get_stats_mean_cpp( + SEXP lfmcmc +) { + WrapLFMCMC(lfmcmc_ptr)(lfmcmc); + return cpp11::doubles(lfmcmc_ptr->get_stats_mean()); +} + [[cpp11::register]] SEXP print_lfmcmc_cpp( SEXP lfmcmc