diff --git a/NAMESPACE b/NAMESPACE index f5ffc57b..13150198 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -4,6 +4,7 @@ export(extract_seromodel_summary) export(fit_seromodel) export(get_exposure_ages) export(get_exposure_matrix) +export(get_foi_central_estimates) export(get_prev_expanded) export(get_table_rhats) export(plot_foi) diff --git a/R/model_comparison.R b/R/model_comparison.R index c2bb29c9..32825a8c 100644 --- a/R/model_comparison.R +++ b/R/model_comparison.R @@ -15,7 +15,7 @@ #' } #' @export get_table_rhats <- function(seromodel_object) { - rhats <- bayesplot::rhat(seromodel_object$fit, "foi") + rhats <- bayesplot::rhat(seromodel_object$seromodel_fit, "foi") if (any(is.nan(rhats))) { rhats[which(is.nan(rhats))] <- 0 diff --git a/R/modelling.R b/R/modelling.R index 1da29f19..36a25f80 100644 --- a/R/modelling.R +++ b/R/modelling.R @@ -68,7 +68,8 @@ run_seromodel <- function(serodata, foi_model, " finished running ------")) if (print_summary){ - print(t(seromodel_object$model_summary)) + model_summary <- extract_seromodel_summary(seromodel_object = seromodel_object) + print(t(model_summary)) } return(seromodel_object) } @@ -146,27 +147,18 @@ fit_seromodel <- function(serodata, ) n_warmup <- floor(n_iters / 2) - if (foi_model == "tv_normal_log") { f_init <- function() { list(log_foi = rep(-3, length(exposure_ages))) - } - lower_quantile = 0.1 - upper_quantile = 0.9 - medianv_quantile = 0.5 } - + } else { f_init <- function() { list(foi = rep(0.01, length(exposure_ages))) } - lower_quantile = 0.05 - upper_quantile = 0.95 - medianv_quantile = 0.5 - } - fit <- rstan::sampling( + seromodel_fit <- rstan::sampling( model, data = stan_data, iter = n_iters, @@ -182,64 +174,21 @@ fit_seromodel <- function(serodata, chain_id = 0 # https://github.com/stan-dev/rstan/issues/761#issuecomment-647029649 ) - if (class(fit@sim$samples) != "NULL") { - loo_fit <- loo::loo(fit, save_psis = TRUE, "logLikelihood") - foi <- rstan::extract(fit, "foi", inc_warmup = FALSE)[[1]] - # foi <- 
rstan::extract(fit, "foi", inc_warmup = TRUE, permuted=FALSE)[[1]] - # generates central estimations - foi_cent_est <- data.frame( - year = exposure_years, - lower = apply(foi, 2, function(x) quantile(x, lower_quantile)), - - upper = apply(foi, 2, function(x) quantile(x, upper_quantile)), - - medianv = apply(foi, 2, function(x) quantile(x, medianv_quantile)) - ) - - - # generates a sample of iterations - if (n_iters >= 2000) { - foi_post_s <- dplyr::sample_n(as.data.frame(foi), size = 1000) - colnames(foi_post_s) <- exposure_years - } else { - foi_post_s <- as.data.frame(foi) - colnames(foi_post_s) <- exposure_years - } - + if (class(seromodel_fit@sim$samples) != "NULL") { seromodel_object <- list( - fit = fit, + seromodel_fit = seromodel_fit, serodata = serodata, stan_data = stan_data, exposure_years = exposure_years, - exposure_ages = exposure_ages, - n_iters = n_iters, - n_thin = n_thin, - n_warmup = n_warmup, - foi_model = foi_model, - delta = delta, - m_treed = m_treed, - loo_fit = loo_fit, - foi_cent_est = foi_cent_est, - foi_post_s = foi_post_s + exposure_ages = exposure_ages ) - seromodel_object$model_summary <- - extract_seromodel_summary(seromodel_object) } else { - loo_fit <- c(-1e10, 0) seromodel_object <- list( - fit = "no model", + seromodel_fit = "no model", serodata = serodata, stan_data = stan_data, exposure_years = exposure_years, - exposure_ages = exposure_ages, - n_iters = n_iters, - n_thin = n_thin, - n_warmup = n_warmup, - model = foi_model, - delta = delta, - m_treed = m_treed, - loo_fit = loo_fit, - model_summary = NA + exposure_ages = exposure_ages ) } @@ -287,6 +236,46 @@ get_exposure_matrix <- function(serodata) { return(exposure_output) } +#' Function that generates the central estimates for the fitted forced FoI +#' +#' @param seromodel_object Object containing the results of fitting a model by means of \link{run_seromodel}. +#' generated by means of \link{get_exposure_ages}. +#' @return \code{foi_central_estimates}. 
Central estimates for the fitted forced FoI
+#' @examples
+#' \dontrun{
+#' data(chagas2012)
+#' serodata <- prepare_serodata(chagas2012)
+#' seromodel_object <- fit_seromodel(serodata = serodata,
+#'                                   foi_model = "constant")
+#' foi_central_estimates <- get_foi_central_estimates(seromodel_object)
+#' }
+#'
+#' @export
+get_foi_central_estimates <- function(seromodel_object) {
+  seromodel_fit <- seromodel_object$seromodel_fit
+  # fit_seromodel() stores the string "no model" when no samples were drawn
+  if (is.character(seromodel_fit)) {
+    stop("seromodel_object does not contain a fitted model", call. = FALSE)
+  }
+  # tv_normal_log is summarised with an 80% interval, other models with 90%
+  if (seromodel_fit@model_name == "tv_normal_log") {
+    lower_quantile <- 0.1
+    upper_quantile <- 0.9
+  } else {
+    lower_quantile <- 0.05
+    upper_quantile <- 0.95
+  }
+  medianv_quantile <- 0.5
+  # posterior draws of foi (warmup excluded); one column per exposure year
+  foi <- rstan::extract(seromodel_fit, "foi", inc_warmup = FALSE)[[1]]
+  # column-wise quantiles give the central estimates returned to the caller
+  data.frame(
+    year = seromodel_object$exposure_years,
+    lower = apply(foi, 2, function(x) quantile(x, lower_quantile)),
+    upper = apply(foi, 2, function(x) quantile(x, upper_quantile)),
+    medianv = apply(foi, 2, function(x) quantile(x, medianv_quantile))
+  )
+}
 
 #' Method to extact a summary of the specified serological model object
 #'
@@ -321,26 +310,24 @@ get_exposure_matrix <- function(serodata) {
 #' }
 #' @export
 extract_seromodel_summary <- function(seromodel_object) {
-  foi_model <- seromodel_object$foi_model
-  serodata <- seromodel_object$serodata
   #------- Loo estimates
-
-  loo_fit <- seromodel_object$loo_fit
+  loo_fit <- loo::loo(seromodel_object$seromodel_fit, save_psis = TRUE, "logLikelihood")
   if (sum(is.na(loo_fit)) < 1) {
     lll <- as.numeric((round(loo_fit$estimates[1, ], 2)))
   } else {
     lll <- c(-1e10, 0)
   }
+  #-------
   model_summary <- data.frame(
-    foi_model = foi_model,
-    dataset = serodata$survey[1],
-    country = serodata$country[1],
-    year = serodata$tsur[1],
-    test = serodata$test[1],
-    antibody = serodata$antibody[1],
-    n_sample = sum(serodata$total),
-    n_agec = length(serodata$age_mean_f),
-    n_iter = seromodel_object$n_iters,
+    foi_model = 
seromodel_object$seromodel_fit@model_name, + dataset = unique(seromodel_object$serodata$survey), + country = unique(seromodel_object$serodata$country), + year = unique(seromodel_object$serodata$tsur), + test = unique(seromodel_object$serodata$test), + antibody = unique(seromodel_object$serodata$antibody), + n_sample = sum(seromodel_object$serodata$total), + n_agec = length(seromodel_object$serodata$age_mean_f), + n_iter = seromodel_object$seromodel_fit@sim$iter, elpd = lll[1], se = lll[2], converged = NA @@ -360,7 +347,7 @@ extract_seromodel_summary <- function(seromodel_object) { #' #' This function computes the corresponding binomial confidence intervals for the obtained prevalence based on a fitting #' of the Force-of-Infection \code{foi} for plotting an analysis purposes. -#' @param foi Object containing the information of the force of infection. It is obtained from \code{rstan::extract(seromodel_object$fit, "foi", inc_warmup = FALSE)[[1]]}. +#' @param foi Object containing the information of the force of infection. It is obtained from \code{rstan::extract(seromodel_object$seromodel_fit, "foi", inc_warmup = FALSE)[[1]]}. #' @param serodata A data frame containing the data from a seroprevalence survey. For further details refer to \link{run_seromodel}. #' @param bin_data TBD #' @return \code{prev_final}. The expanded prevalence data. This is used for plotting purposes in the \code{visualization} module. 
@@ -370,7 +357,7 @@ extract_seromodel_summary <- function(seromodel_object) { #' serodata <- prepare_serodata(chagas2012) #' seromodel_object <- run_seromodel(serodata = serodata, #' foi_model = "constant") -#' foi <- rstan::extract(seromodel_object$fit, "foi")[[1]] +#' foi <- rstan::extract(seromodel_object$seromodel_fit, "foi")[[1]] #' get_prev_expanded <- function(foi, serodata) #' } #' @export diff --git a/R/visualisation.R b/R/visualisation.R index 775b06eb..2188a76f 100644 --- a/R/visualisation.R +++ b/R/visualisation.R @@ -64,10 +64,10 @@ plot_seroprev <- function(serodata, plot_seroprev_fitted <- function(seromodel_object, size_text = 6) { - if (is.character(seromodel_object$fit) == FALSE) { - if (class(seromodel_object$fit@sim$samples) != "NULL" ) { + if (is.character(seromodel_object$seromodel_fit) == FALSE) { + if (class(seromodel_object$seromodel_fit@sim$samples) != "NULL" ) { - foi <- rstan::extract(seromodel_object$fit, "foi", inc_warmup = FALSE)[[1]] + foi <- rstan::extract(seromodel_object$seromodel_fit, "foi", inc_warmup = FALSE)[[1]] prev_expanded <- get_prev_expanded(foi, serodata = seromodel_object$serodata, bin_data = TRUE) prev_plot <- ggplot2::ggplot(prev_expanded) + @@ -148,14 +148,14 @@ plot_foi <- function(seromodel_object, max_lambda = NA, size_text = 25, foi_sim = NULL) { - if (is.character(seromodel_object$fit) == FALSE) { - if (class(seromodel_object$fit@sim$samples) != "NULL") { - foi <- rstan::extract(seromodel_object$fit, + if (is.character(seromodel_object$seromodel_fit) == FALSE) { + if (class(seromodel_object$seromodel_fit@sim$samples) != "NULL") { + foi <- rstan::extract(seromodel_object$seromodel_fit, "foi", inc_warmup = FALSE)[[1]] #-------- This bit is to get the actual length of the foi data - foi_data <- seromodel_object$foi_cent_est + foi_data <- get_foi_central_estimates(seromodel_object = seromodel_object) #-------- foi_data$medianv[1] <- NA @@ -243,8 +243,8 @@ plot_foi <- function(seromodel_object, #' @export plot_rhats 
<- function(seromodel_object, size_text = 25) { - if (is.character(seromodel_object$fit) == FALSE) { - if (class(seromodel_object$fit@sim$samples) != "NULL") { + if (is.character(seromodel_object$seromodel_fit) == FALSE) { + if (class(seromodel_object$seromodel_fit@sim$samples) != "NULL") { rhats <- get_table_rhats(seromodel_object) rhats_plot <- @@ -310,8 +310,8 @@ plot_seromodel <- function(seromodel_object, max_lambda = NA, size_text = 25, foi_sim = NULL) { - if (is.character(seromodel_object$fit) == FALSE) { - if (class(seromodel_object$fit@sim$samples) != "NULL") { + if (is.character(seromodel_object$seromodel_fit) == FALSE) { + if (class(seromodel_object$seromodel_fit@sim$samples) != "NULL") { prev_plot <- plot_seroprev_fitted(seromodel_object = seromodel_object, size_text = size_text) @@ -324,9 +324,9 @@ plot_seromodel <- function(seromodel_object, rhats_plot <- plot_rhats(seromodel_object = seromodel_object, size_text = size_text) - + model_summary <- extract_seromodel_summary(seromodel_object = seromodel_object) summary_table <- t( - dplyr::select(seromodel_object$model_summary, + dplyr::select(model_summary, c('foi_model', 'dataset', 'elpd', 'se', 'converged'))) summary_plot <- plot_info_table(summary_table, size_text = size_text) diff --git a/man/get_foi_central_estimates.Rd b/man/get_foi_central_estimates.Rd new file mode 100644 index 00000000..d85285e8 --- /dev/null +++ b/man/get_foi_central_estimates.Rd @@ -0,0 +1,28 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/modelling.R +\name{get_foi_central_estimates} +\alias{get_foi_central_estimates} +\title{Function that generates the central estimates for the fitted forced FoI} +\usage{ +get_foi_central_estimates(seromodel_object) +} +\arguments{ +\item{seromodel_object}{Object containing the results of fitting a model by means of \link{run_seromodel}. +generated by means of \link{get_exposure_ages}.} +} +\value{ +\code{foi_central_estimates}. 
Central estimates for the fitted forced FoI +} +\description{ +Function that generates the central estimates for the fitted forced FoI +} +\examples{ +\dontrun{ +data(chagas2012) +serodata <- prepare_serodata(chagas2012) +seromodel_object <- fit_seromodel(serodata = serodata, + foi_model = "constant") +foi_central_estimates <- get_foi_central_estimates(seromodel_object) +} + +} diff --git a/tests/testthat/models_serialization.R b/tests/testthat/models_serialization.R index 6306fac0..0daa2b37 100644 --- a/tests/testthat/models_serialization.R +++ b/tests/testthat/models_serialization.R @@ -1,5 +1,4 @@ library(devtools) -library(dplyr) library(serofoi) library(testthat) @@ -7,7 +6,7 @@ set.seed(1234) # For reproducibility #----- Read and prepare data data("simdata_large_epi") -simdata <- simdata_large_epi %>% prepare_serodata() +simdata <- prepare_serodata(simdata_large_epi) no_transm <- 0.0000000001 big_outbreak <- 1.5 foi_sim <- c(rep(no_transm, 32), rep(big_outbreak, 3), rep(no_transm, 15)) # 1 epidemics @@ -23,11 +22,12 @@ models_list <- lapply(models_to_run, serodata = simdata, n_iters = 1000) -model_constant_json <- jsonlite::serializeJSON(models_list[[1]]) -write_json(model_constant_json, testthat::test_path("extdata", "model_constant.json")) +saveRDS(models_list[[1]], + testthat::test_path("extdata", "model_constant.RDS")) -model_tv_normal_json <- jsonlite::serializeJSON(models_list[[2]]) -write_json(model_tv_normal_json, testthat::test_path("extdata", "model_tv_normal.json")) +saveRDS(models_list[[2]], + testthat::test_path("extdata", "model_tv_normal.RDS")) + +saveRDS(models_list[[3]], + testthat::test_path("extdata", "model_tv_normal_log.RDS")) -model_tv_normal_log_json <- jsonlite::serializeJSON(models_list[[3]]) -write_json(model_tv_normal_json, testthat::test_path("extdata", "model_tv_normal_log.json")) diff --git a/tests/testthat/test_issue_47.R b/tests/testthat/test_issue_47.R index e46b2740..096b8d2d 100644 --- a/tests/testthat/test_issue_47.R +++ 
b/tests/testthat/test_issue_47.R @@ -12,7 +12,7 @@ test_that("issue 47", { # Error reproduction model_test <- run_seromodel(data_issue, foi_model = "tv_normal", print_summary = FALSE) - foi <- rstan::extract(model_test$fit, "foi", inc_warmup = FALSE)[[1]] + foi <- rstan::extract(model_test$seromodel_fit, "foi", inc_warmup = FALSE)[[1]] age_max <- max(data_issue$age_mean_f) prev_expanded <- get_prev_expanded(foi, serodata = data_issue) diff --git a/tests/testthat/test_modelling.R b/tests/testthat/test_modelling.R index edea5876..f8662a31 100644 --- a/tests/testthat/test_modelling.R +++ b/tests/testthat/test_modelling.R @@ -27,7 +27,7 @@ test_that("individual models", { n_iters = 1000, print_summary = FALSE) - foi <- rstan::extract(model_object$fit, "foi", inc_warmup = FALSE)[[1]] + foi <- rstan::extract(model_object$seromodel_fit, "foi", inc_warmup = FALSE)[[1]] prev_expanded <- get_prev_expanded(foi, serodata = model_object$serodata) prev_expanded_constant <- readRDS(data_constant_path) @@ -40,7 +40,7 @@ test_that("individual models", { foi_model = model_name, n_iters = 1000) - foi <- rstan::extract(model_object$fit, "foi", inc_warmup = FALSE)[[1]] + foi <- rstan::extract(model_object$seromodel_fit, "foi", inc_warmup = FALSE)[[1]] prev_expanded <- get_prev_expanded(foi, serodata = model_object$serodata) prev_expanded_tv_normal <- readRDS(data_tv_normal_path) testthat::expect_equal(prev_expanded, prev_expanded_tv_normal, tolerance = TRUE) @@ -52,7 +52,7 @@ test_that("individual models", { foi_model = model_name, n_iters = 1000) - foi <- rstan::extract(model_object$fit, "foi", inc_warmup = FALSE)[[1]] + foi <- rstan::extract(model_object$seromodel_fit, "foi", inc_warmup = FALSE)[[1]] prev_expanded <- get_prev_expanded(foi, serodata = model_object$serodata) prev_expanded_tv_normal <- readRDS(data_tv_normal_path) testthat::expect_equal(prev_expanded, prev_expanded_tv_normal_log, tolerance = TRUE) diff --git a/tests/testthat/test_visualisation.R 
b/tests/testthat/test_visualisation.R index 52ef09fa..896788c1 100644 --- a/tests/testthat/test_visualisation.R +++ b/tests/testthat/test_visualisation.R @@ -10,12 +10,11 @@ test_that("individual models", { set.seed(1234) # For reproducibility library(devtools) - library(dplyr) library(vdiffr) library(jsonlite) - data("simdata_large_epi") - simdata <- simdata_large_epi %>% prepare_serodata() + data(simdata_large_epi) + simdata <- prepare_serodata(simdata_large_epi) no_transm <- 0.0000000001 big_outbreak <- 1.5 foi_sim <- c(rep(no_transm, 32), rep(big_outbreak, 3), rep(no_transm, 15)) # 1 epidemics @@ -25,28 +24,25 @@ test_that("individual models", { size_text <- 6 max_lambda <- 1.55 - model_constant_json <- jsonlite::fromJSON(testthat::test_path("extdata", "model_constant.json")) - model_constant <- jsonlite::unserializeJSON(model_constant_json) - constant_plot <- plot_seromodel(model_constant, - size_text = size_text, - max_lambda = max_lambda, - foi_sim = foi_sim + model_constant <- readRDS(testthat::test_path("extdata", "model_constant.RDS")) + constant_plot <- plot_seromodel(seromodel_object = model_constant, + size_text = size_text, + max_lambda = max_lambda, + foi_sim = foi_sim ) - model_tv_normal_json <- fromJSON(testthat::test_path("extdata", "model_tv_normal.json")) - model_tv_normal <- jsonlite::unserializeJSON(model_tv_normal_json) - tv_normal_plot <- plot_seromodel(model_tv_normal, - size_text = size_text, - max_lambda = max_lambda, - foi_sim = foi_sim + model_tv_normal <- readRDS(testthat::test_path("extdata", "model_tv_normal.RDS")) + tv_normal_plot <- plot_seromodel(seromodel_object = model_tv_normal, + size_text = size_text, + max_lambda = max_lambda, + foi_sim = foi_sim ) - model_tv_normal_log_json <- fromJSON(testthat::test_path("extdata", "model_tv_normal_log.json")) - model_tv_normal_log <- jsonlite::unserializeJSON(model_tv_normal_log_json) - tv_normal_log_plot <- plot_seromodel(model_tv_normal_log, - size_text = size_text, - max_lambda = 
max_lambda, - foi_sim = foi_sim + model_tv_normal_log <- readRDS(testthat::test_path("extdata", "model_tv_normal_log.RDS")) + tv_normal_log_plot <- plot_seromodel(seromodel_object = model_tv_normal_log, + size_text = size_text, + max_lambda = max_lambda, + foi_sim = foi_sim ) plot_arrange <- cowplot::plot_grid(constant_plot,