Skip to content

Commit

Permalink
Merge pull request #109 from epiverse-trace/refac-fit_seromodel
Browse files Browse the repository at this point in the history
 Simplify fit_seromodel() output and related refactorizations
  • Loading branch information
jpavlich authored Aug 24, 2023
2 parents 46e6531 + ba03f2b commit e3b962f
Show file tree
Hide file tree
Showing 9 changed files with 134 additions and 122 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ export(extract_seromodel_summary)
export(fit_seromodel)
export(get_exposure_ages)
export(get_exposure_matrix)
export(get_foi_central_estimates)
export(get_prev_expanded)
export(get_table_rhats)
export(plot_foi)
Expand Down
2 changes: 1 addition & 1 deletion R/model_comparison.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#' }
#' @export
get_table_rhats <- function(seromodel_object) {
rhats <- bayesplot::rhat(seromodel_object$fit, "foi")
rhats <- bayesplot::rhat(seromodel_object$seromodel_fit, "foi")

if (any(is.nan(rhats))) {
rhats[which(is.nan(rhats))] <- 0
Expand Down
137 changes: 62 additions & 75 deletions R/modelling.R
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ run_seromodel <- function(serodata,
foi_model,
" finished running ------"))
if (print_summary){
print(t(seromodel_object$model_summary))
model_summary <- extract_seromodel_summary(seromodel_object = seromodel_object)
print(t(model_summary))
}
return(seromodel_object)
}
Expand Down Expand Up @@ -146,27 +147,18 @@ fit_seromodel <- function(serodata,
)

n_warmup <- floor(n_iters / 2)

if (foi_model == "tv_normal_log") {
f_init <- function() {
list(log_foi = rep(-3, length(exposure_ages)))
}
lower_quantile = 0.1
upper_quantile = 0.9
medianv_quantile = 0.5
}

}
else {
f_init <- function() {
list(foi = rep(0.01, length(exposure_ages)))
}
lower_quantile = 0.05
upper_quantile = 0.95
medianv_quantile = 0.5

}

fit <- rstan::sampling(
seromodel_fit <- rstan::sampling(
model,
data = stan_data,
iter = n_iters,
Expand All @@ -182,64 +174,21 @@ fit_seromodel <- function(serodata,
chain_id = 0 # https://github.com/stan-dev/rstan/issues/761#issuecomment-647029649
)

if (class(fit@sim$samples) != "NULL") {
loo_fit <- loo::loo(fit, save_psis = TRUE, "logLikelihood")
foi <- rstan::extract(fit, "foi", inc_warmup = FALSE)[[1]]
# foi <- rstan::extract(fit, "foi", inc_warmup = TRUE, permuted=FALSE)[[1]]
# generates central estimations
foi_cent_est <- data.frame(
year = exposure_years,
lower = apply(foi, 2, function(x) quantile(x, lower_quantile)),

upper = apply(foi, 2, function(x) quantile(x, upper_quantile)),

medianv = apply(foi, 2, function(x) quantile(x, medianv_quantile))
)


# generates a sample of iterations
if (n_iters >= 2000) {
foi_post_s <- dplyr::sample_n(as.data.frame(foi), size = 1000)
colnames(foi_post_s) <- exposure_years
} else {
foi_post_s <- as.data.frame(foi)
colnames(foi_post_s) <- exposure_years
}

if (class(seromodel_fit@sim$samples) != "NULL") {
seromodel_object <- list(
fit = fit,
seromodel_fit = seromodel_fit,
serodata = serodata,
stan_data = stan_data,
exposure_years = exposure_years,
exposure_ages = exposure_ages,
n_iters = n_iters,
n_thin = n_thin,
n_warmup = n_warmup,
foi_model = foi_model,
delta = delta,
m_treed = m_treed,
loo_fit = loo_fit,
foi_cent_est = foi_cent_est,
foi_post_s = foi_post_s
exposure_ages = exposure_ages
)
seromodel_object$model_summary <-
extract_seromodel_summary(seromodel_object)
} else {
loo_fit <- c(-1e10, 0)
seromodel_object <- list(
fit = "no model",
seromodel_fit = "no model",
serodata = serodata,
stan_data = stan_data,
exposure_years = exposure_years,
exposure_ages = exposure_ages,
n_iters = n_iters,
n_thin = n_thin,
n_warmup = n_warmup,
model = foi_model,
delta = delta,
m_treed = m_treed,
loo_fit = loo_fit,
model_summary = NA
exposure_ages = exposure_ages
)
}

Expand Down Expand Up @@ -287,6 +236,46 @@ get_exposure_matrix <- function(serodata) {
return(exposure_output)
}

#' Function that generates the central estimates for the fitted forced FoI
#'
#' @param seromodel_object Object containing the results of fitting a model by means of \link{run_seromodel}.
#' generated by means of \link{get_exposure_ages}.
#' @return \code{foi_central_estimates}. Central estimates for the fitted forced FoI
#' @examples
#' \dontrun{
#' data(chagas2012)
#' serodata <- prepare_serodata(chagas2012)
#' seromodel_object <- fit_seromodel(serodata = serodata,
#' foi_model = "constant")
#' foi_central_estimates <- get_foi_central_estimates(seromodel_object)
#' }
#'
#' @export
get_foi_central_estimates <- function(seromodel_object) {

if (seromodel_object$seromodel_fit@model_name == "tv_normal_log") {
lower_quantile = 0.1
upper_quantile = 0.9
medianv_quantile = 0.5
}
else {
lower_quantile = 0.05
upper_quantile = 0.95
medianv_quantile = 0.5
}
# extracts foi from stan fit
foi <- rstan::extract(seromodel_object$seromodel_fit, "foi", inc_warmup = FALSE)[[1]]
# generates central estimations
foi_central_estimates <- data.frame(
year = seromodel_object$exposure_years,
lower = apply(foi, 2, function(x) quantile(x, lower_quantile)),

upper = apply(foi, 2, function(x) quantile(x, upper_quantile)),

medianv = apply(foi, 2, function(x) quantile(x, medianv_quantile))
)
return(foi_central_estimates)
}

#' Method to extact a summary of the specified serological model object
#'
Expand Down Expand Up @@ -321,26 +310,24 @@ get_exposure_matrix <- function(serodata) {
#' }
#' @export
extract_seromodel_summary <- function(seromodel_object) {
foi_model <- seromodel_object$foi_model
serodata <- seromodel_object$serodata
#------- Loo estimates

loo_fit <- seromodel_object$loo_fit
loo_fit <- loo::loo(seromodel_object$seromodel_fit, save_psis = TRUE, "logLikelihood")
if (sum(is.na(loo_fit)) < 1) {
lll <- as.numeric((round(loo_fit$estimates[1, ], 2)))
} else {
lll <- c(-1e10, 0)
}
#-------
model_summary <- data.frame(
foi_model = foi_model,
dataset = serodata$survey[1],
country = serodata$country[1],
year = serodata$tsur[1],
test = serodata$test[1],
antibody = serodata$antibody[1],
n_sample = sum(serodata$total),
n_agec = length(serodata$age_mean_f),
n_iter = seromodel_object$n_iters,
foi_model = seromodel_object$seromodel_fit@model_name,
dataset = unique(seromodel_object$serodata$survey),
country = unique(seromodel_object$serodata$country),
year = unique(seromodel_object$serodata$tsur),
test = unique(seromodel_object$serodata$test),
antibody = unique(seromodel_object$serodata$antibody),
n_sample = sum(seromodel_object$serodata$total),
n_agec = length(seromodel_object$serodata$age_mean_f),
n_iter = seromodel_object$seromodel_fit@sim$iter,
elpd = lll[1],
se = lll[2],
converged = NA
Expand All @@ -360,7 +347,7 @@ extract_seromodel_summary <- function(seromodel_object) {
#'
#' This function computes the corresponding binomial confidence intervals for the obtained prevalence based on a fitting
#' of the Force-of-Infection \code{foi} for plotting an analysis purposes.
#' @param foi Object containing the information of the force of infection. It is obtained from \code{rstan::extract(seromodel_object$fit, "foi", inc_warmup = FALSE)[[1]]}.
#' @param foi Object containing the information of the force of infection. It is obtained from \code{rstan::extract(seromodel_object$seromodel, "foi", inc_warmup = FALSE)[[1]]}.
#' @param serodata A data frame containing the data from a seroprevalence survey. For further details refer to \link{run_seromodel}.
#' @param bin_data TBD
#' @return \code{prev_final}. The expanded prevalence data. This is used for plotting purposes in the \code{visualization} module.
Expand All @@ -370,7 +357,7 @@ extract_seromodel_summary <- function(seromodel_object) {
#' serodata <- prepare_serodata(chagas2012)
#' seromodel_object <- run_seromodel(serodata = serodata,
#' foi_model = "constant")
#' foi <- rstan::extract(seromodel_object$fit, "foi")[[1]]
#' foi <- rstan::extract(seromodel_object$seromodel_fit, "foi")[[1]]
#' get_prev_expanded <- function(foi, serodata)
#' }
#' @export
Expand Down
26 changes: 13 additions & 13 deletions R/visualisation.R
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,10 @@ plot_seroprev <- function(serodata,
plot_seroprev_fitted <- function(seromodel_object,
size_text = 6) {

if (is.character(seromodel_object$fit) == FALSE) {
if (class(seromodel_object$fit@sim$samples) != "NULL" ) {
if (is.character(seromodel_object$seromodel_fit) == FALSE) {
if (class(seromodel_object$seromodel_fit@sim$samples) != "NULL" ) {

foi <- rstan::extract(seromodel_object$fit, "foi", inc_warmup = FALSE)[[1]]
foi <- rstan::extract(seromodel_object$seromodel_fit, "foi", inc_warmup = FALSE)[[1]]
prev_expanded <- get_prev_expanded(foi, serodata = seromodel_object$serodata, bin_data = TRUE)
prev_plot <-
ggplot2::ggplot(prev_expanded) +
Expand Down Expand Up @@ -148,14 +148,14 @@ plot_foi <- function(seromodel_object,
max_lambda = NA,
size_text = 25,
foi_sim = NULL) {
if (is.character(seromodel_object$fit) == FALSE) {
if (class(seromodel_object$fit@sim$samples) != "NULL") {
foi <- rstan::extract(seromodel_object$fit,
if (is.character(seromodel_object$seromodel_fit) == FALSE) {
if (class(seromodel_object$seromodel_fit@sim$samples) != "NULL") {
foi <- rstan::extract(seromodel_object$seromodel_fit,
"foi",
inc_warmup = FALSE)[[1]]

#-------- This bit is to get the actual length of the foi data
foi_data <- seromodel_object$foi_cent_est
foi_data <- get_foi_central_estimates(seromodel_object = seromodel_object)

#--------
foi_data$medianv[1] <- NA
Expand Down Expand Up @@ -243,8 +243,8 @@ plot_foi <- function(seromodel_object,
#' @export
plot_rhats <- function(seromodel_object,
size_text = 25) {
if (is.character(seromodel_object$fit) == FALSE) {
if (class(seromodel_object$fit@sim$samples) != "NULL") {
if (is.character(seromodel_object$seromodel_fit) == FALSE) {
if (class(seromodel_object$seromodel_fit@sim$samples) != "NULL") {
rhats <- get_table_rhats(seromodel_object)

rhats_plot <-
Expand Down Expand Up @@ -310,8 +310,8 @@ plot_seromodel <- function(seromodel_object,
max_lambda = NA,
size_text = 25,
foi_sim = NULL) {
if (is.character(seromodel_object$fit) == FALSE) {
if (class(seromodel_object$fit@sim$samples) != "NULL") {
if (is.character(seromodel_object$seromodel_fit) == FALSE) {
if (class(seromodel_object$seromodel_fit@sim$samples) != "NULL") {
prev_plot <- plot_seroprev_fitted(seromodel_object = seromodel_object,
size_text = size_text)

Expand All @@ -324,9 +324,9 @@ plot_seromodel <- function(seromodel_object,

rhats_plot <- plot_rhats(seromodel_object = seromodel_object,
size_text = size_text)

model_summary <- extract_seromodel_summary(seromodel_object = seromodel_object)
summary_table <- t(
dplyr::select(seromodel_object$model_summary,
dplyr::select(model_summary,
c('foi_model', 'dataset', 'elpd', 'se', 'converged')))
summary_plot <-
plot_info_table(summary_table, size_text = size_text)
Expand Down
28 changes: 28 additions & 0 deletions man/get_foi_central_estimates.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 8 additions & 8 deletions tests/testthat/models_serialization.R
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
library(devtools)
library(dplyr)
library(serofoi)
library(testthat)

set.seed(1234) # For reproducibility

#----- Read and prepare data
data("simdata_large_epi")
simdata <- simdata_large_epi %>% prepare_serodata()
simdata <- prepare_serodata(simdata_large_epi)
no_transm <- 0.0000000001
big_outbreak <- 1.5
foi_sim <- c(rep(no_transm, 32), rep(big_outbreak, 3), rep(no_transm, 15)) # 1 epidemics
Expand All @@ -23,11 +22,12 @@ models_list <- lapply(models_to_run,
serodata = simdata,
n_iters = 1000)

model_constant_json <- jsonlite::serializeJSON(models_list[[1]])
write_json(model_constant_json, testthat::test_path("extdata", "model_constant.json"))
saveRDS(models_list[[1]],
testthat::test_path("extdata", "model_constant.RDS"))

model_tv_normal_json <- jsonlite::serializeJSON(models_list[[2]])
write_json(model_tv_normal_json, testthat::test_path("extdata", "model_tv_normal.json"))
saveRDS(models_list[[2]],
testthat::test_path("extdata", "model_tv_normal.RDS"))

saveRDS(models_list[[3]],
testthat::test_path("extdata", "model_tv_normal_log.RDS"))

model_tv_normal_log_json <- jsonlite::serializeJSON(models_list[[3]])
write_json(model_tv_normal_json, testthat::test_path("extdata", "model_tv_normal_log.json"))
2 changes: 1 addition & 1 deletion tests/testthat/test_issue_47.R
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ test_that("issue 47", {

# Error reproduction
model_test <- run_seromodel(data_issue, foi_model = "tv_normal", print_summary = FALSE)
foi <- rstan::extract(model_test$fit, "foi", inc_warmup = FALSE)[[1]]
foi <- rstan::extract(model_test$seromodel_fit, "foi", inc_warmup = FALSE)[[1]]
age_max <- max(data_issue$age_mean_f)
prev_expanded <- get_prev_expanded(foi, serodata = data_issue)

Expand Down
6 changes: 3 additions & 3 deletions tests/testthat/test_modelling.R
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ test_that("individual models", {
n_iters = 1000,
print_summary = FALSE)

foi <- rstan::extract(model_object$fit, "foi", inc_warmup = FALSE)[[1]]
foi <- rstan::extract(model_object$seromodel_fit, "foi", inc_warmup = FALSE)[[1]]
prev_expanded <- get_prev_expanded(foi, serodata = model_object$serodata)
prev_expanded_constant <- readRDS(data_constant_path)

Expand All @@ -40,7 +40,7 @@ test_that("individual models", {
foi_model = model_name,
n_iters = 1000)

foi <- rstan::extract(model_object$fit, "foi", inc_warmup = FALSE)[[1]]
foi <- rstan::extract(model_object$seromodel_fit, "foi", inc_warmup = FALSE)[[1]]
prev_expanded <- get_prev_expanded(foi, serodata = model_object$serodata)
prev_expanded_tv_normal <- readRDS(data_tv_normal_path)
testthat::expect_equal(prev_expanded, prev_expanded_tv_normal, tolerance = TRUE)
Expand All @@ -52,7 +52,7 @@ test_that("individual models", {
foi_model = model_name,
n_iters = 1000)

foi <- rstan::extract(model_object$fit, "foi", inc_warmup = FALSE)[[1]]
foi <- rstan::extract(model_object$seromodel_fit, "foi", inc_warmup = FALSE)[[1]]
prev_expanded <- get_prev_expanded(foi, serodata = model_object$serodata)
prev_expanded_tv_normal <- readRDS(data_tv_normal_path)
testthat::expect_equal(prev_expanded, prev_expanded_tv_normal_log, tolerance = TRUE)
Expand Down
Loading

0 comments on commit e3b962f

Please sign in to comment.