From a02b92b6ecf4d11bd22c20abf5cd8c96977da253 Mon Sep 17 00:00:00 2001 From: ntorresd Date: Tue, 15 Aug 2023 09:05:57 -0500 Subject: [PATCH] feat: add function `get_foi_central_estimates()` to the modelling module. This change is meant to simplify `fit_seromodel()`. This commit also changes the name of the stanfit object in the output of `fit_seromodel()` from `fit` to `seromodel_fit`. --- NAMESPACE | 1 + R/modelling.R | 85 +++++++++++++++++++++++++---------------------- R/visualisation.R | 22 ++++++------ 3 files changed, 58 insertions(+), 50 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index f5ffc57b..13150198 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -4,6 +4,7 @@ export(extract_seromodel_summary) export(fit_seromodel) export(get_exposure_ages) export(get_exposure_matrix) +export(get_foi_central_estimates) export(get_prev_expanded) export(get_table_rhats) export(plot_foi) diff --git a/R/modelling.R b/R/modelling.R index 1da29f19..d3d3f129 100644 --- a/R/modelling.R +++ b/R/modelling.R @@ -146,27 +146,18 @@ fit_seromodel <- function(serodata, ) n_warmup <- floor(n_iters / 2) - if (foi_model == "tv_normal_log") { f_init <- function() { list(log_foi = rep(-3, length(exposure_ages))) - } - lower_quantile = 0.1 - upper_quantile = 0.9 - medianv_quantile = 0.5 } - + } else { f_init <- function() { list(foi = rep(0.01, length(exposure_ages))) } - lower_quantile = 0.05 - upper_quantile = 0.95 - medianv_quantile = 0.5 - } - fit <- rstan::sampling( + seromodel_fit <- rstan::sampling( model, data = stan_data, iter = n_iters, @@ -182,32 +173,9 @@ fit_seromodel <- function(serodata, chain_id = 0 # https://github.com/stan-dev/rstan/issues/761#issuecomment-647029649 ) - if (class(fit@sim$samples) != "NULL") { - loo_fit <- loo::loo(fit, save_psis = TRUE, "logLikelihood") - foi <- rstan::extract(fit, "foi", inc_warmup = FALSE)[[1]] - # foi <- rstan::extract(fit, "foi", inc_warmup = TRUE, permuted=FALSE)[[1]] - # generates central estimations - foi_cent_est <- data.frame( - year 
= exposure_years, - lower = apply(foi, 2, function(x) quantile(x, lower_quantile)), - - upper = apply(foi, 2, function(x) quantile(x, upper_quantile)), - - medianv = apply(foi, 2, function(x) quantile(x, medianv_quantile)) - ) - - - # generates a sample of iterations - if (n_iters >= 2000) { - foi_post_s <- dplyr::sample_n(as.data.frame(foi), size = 1000) - colnames(foi_post_s) <- exposure_years - } else { - foi_post_s <- as.data.frame(foi) - colnames(foi_post_s) <- exposure_years - } - + if (class(seromodel_fit@sim$samples) != "NULL") { seromodel_object <- list( - fit = fit, + seromodel_fit = seromodel_fit, serodata = serodata, stan_data = stan_data, exposure_years = exposure_years, @@ -225,9 +193,8 @@ fit_seromodel <- function(serodata, seromodel_object$model_summary <- extract_seromodel_summary(seromodel_object) } else { - loo_fit <- c(-1e10, 0) seromodel_object <- list( - fit = "no model", + seromodel_fit = "no model", serodata = serodata, stan_data = stan_data, exposure_years = exposure_years, @@ -287,6 +254,46 @@ get_exposure_matrix <- function(serodata) { return(exposure_output) } +#' Function that generates the central estimates for the fitted forced FoI +#' +#' @param seromodel_object Object containing the results of fitting a model by means of \link{run_seromodel}. +#' Its \code{seromodel_fit} and \code{exposure_years} elements are used. +#' @return \code{foi_central_estimates}. 
Central estimates for the fitted forced FoI +#' @examples +#' \dontrun{ +#' data(chagas2012) +#' serodata <- prepare_serodata(chagas2012) +#' seromodel_object <- fit_seromodel(serodata = serodata, +#' foi_model = "constant") +#' foi_central_estimates <- get_foi_central_estimates(seromodel_object) +#' } +#' +#' @export +get_foi_central_estimates <- function(seromodel_object) { + + if (seromodel_object$seromodel_fit@model_name == "tv_normal_log") { + lower_quantile = 0.1 + upper_quantile = 0.9 + medianv_quantile = 0.5 + } + else { + lower_quantile = 0.05 + upper_quantile = 0.95 + medianv_quantile = 0.5 + } + # extracts foi from stan fit + foi <- rstan::extract(seromodel_object$seromodel_fit, "foi", inc_warmup = FALSE)[[1]] + # generates central estimations + foi_central_estimates <- data.frame( + year = seromodel_object$exposure_years, + lower = apply(foi, 2, function(x) quantile(x, lower_quantile)), + + upper = apply(foi, 2, function(x) quantile(x, upper_quantile)), + + medianv = apply(foi, 2, function(x) quantile(x, medianv_quantile)) + ) + return(foi_central_estimates) +} #' Method to extact a summary of the specified serological model object #' @@ -370,7 +377,7 @@ extract_seromodel_summary <- function(seromodel_object) { #' serodata <- prepare_serodata(chagas2012) #' seromodel_object <- run_seromodel(serodata = serodata, #' foi_model = "constant") -#' foi <- rstan::extract(seromodel_object$fit, "foi")[[1]] +#' foi <- rstan::extract(seromodel_object$seromodel_fit, "foi")[[1]] #' get_prev_expanded <- function(foi, serodata) #' } #' @export diff --git a/R/visualisation.R b/R/visualisation.R index 775b06eb..65d23866 100644 --- a/R/visualisation.R +++ b/R/visualisation.R @@ -64,10 +64,10 @@ plot_seroprev <- function(serodata, plot_seroprev_fitted <- function(seromodel_object, size_text = 6) { - if (is.character(seromodel_object$fit) == FALSE) { - if (class(seromodel_object$fit@sim$samples) != "NULL" ) { + if (is.character(seromodel_object$seromodel_fit) == FALSE) { + 
if (class(seromodel_object$seromodel_fit@sim$samples) != "NULL" ) { - foi <- rstan::extract(seromodel_object$fit, "foi", inc_warmup = FALSE)[[1]] + foi <- rstan::extract(seromodel_object$seromodel_fit, "foi", inc_warmup = FALSE)[[1]] prev_expanded <- get_prev_expanded(foi, serodata = seromodel_object$serodata, bin_data = TRUE) prev_plot <- ggplot2::ggplot(prev_expanded) + @@ -148,14 +148,14 @@ plot_foi <- function(seromodel_object, max_lambda = NA, size_text = 25, foi_sim = NULL) { - if (is.character(seromodel_object$fit) == FALSE) { - if (class(seromodel_object$fit@sim$samples) != "NULL") { - foi <- rstan::extract(seromodel_object$fit, + if (is.character(seromodel_object$seromodel_fit) == FALSE) { + if (class(seromodel_object$seromodel_fit@sim$samples) != "NULL") { + foi <- rstan::extract(seromodel_object$seromodel_fit, "foi", inc_warmup = FALSE)[[1]] #-------- This bit is to get the actual length of the foi data - foi_data <- seromodel_object$foi_cent_est + foi_data <- get_foi_central_estimates(seromodel_object = seromodel_object) #-------- foi_data$medianv[1] <- NA @@ -243,8 +243,8 @@ plot_foi <- function(seromodel_object, #' @export plot_rhats <- function(seromodel_object, size_text = 25) { - if (is.character(seromodel_object$fit) == FALSE) { - if (class(seromodel_object$fit@sim$samples) != "NULL") { + if (is.character(seromodel_object$seromodel_fit) == FALSE) { + if (class(seromodel_object$seromodel_fit@sim$samples) != "NULL") { rhats <- get_table_rhats(seromodel_object) rhats_plot <- @@ -310,8 +310,8 @@ plot_seromodel <- function(seromodel_object, max_lambda = NA, size_text = 25, foi_sim = NULL) { - if (is.character(seromodel_object$fit) == FALSE) { - if (class(seromodel_object$fit@sim$samples) != "NULL") { + if (is.character(seromodel_object$seromodel_fit) == FALSE) { + if (class(seromodel_object$seromodel_fit@sim$samples) != "NULL") { prev_plot <- plot_seroprev_fitted(seromodel_object = seromodel_object, size_text = size_text)