Skip to content

Commit

Permalink
feat: add function get_foi_central_estimates() to the modelling mod…
Browse files Browse the repository at this point in the history
…ule.

This change is meant to simplify `fit_seromodel()`. This commit also changes the name of the stanfit object in the output of `fit_seromodel()` from `fit` to `seromodel_fit`.
  • Loading branch information
ntorresd committed Aug 15, 2023
1 parent b528543 commit a02b92b
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 50 deletions.
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ export(extract_seromodel_summary)
export(fit_seromodel)
export(get_exposure_ages)
export(get_exposure_matrix)
export(get_foi_central_estimates)
export(get_prev_expanded)
export(get_table_rhats)
export(plot_foi)
Expand Down
85 changes: 46 additions & 39 deletions R/modelling.R
Original file line number Diff line number Diff line change
Expand Up @@ -146,27 +146,18 @@ fit_seromodel <- function(serodata,
)

n_warmup <- floor(n_iters / 2)

if (foi_model == "tv_normal_log") {
f_init <- function() {
list(log_foi = rep(-3, length(exposure_ages)))
}
lower_quantile = 0.1
upper_quantile = 0.9
medianv_quantile = 0.5
}

}
else {
f_init <- function() {
list(foi = rep(0.01, length(exposure_ages)))
}
lower_quantile = 0.05
upper_quantile = 0.95
medianv_quantile = 0.5

}

fit <- rstan::sampling(
seromodel_fit <- rstan::sampling(
model,
data = stan_data,
iter = n_iters,
Expand All @@ -182,32 +173,9 @@ fit_seromodel <- function(serodata,
chain_id = 0 # https://github.com/stan-dev/rstan/issues/761#issuecomment-647029649
)

if (class(fit@sim$samples) != "NULL") {
loo_fit <- loo::loo(fit, save_psis = TRUE, "logLikelihood")
foi <- rstan::extract(fit, "foi", inc_warmup = FALSE)[[1]]
# foi <- rstan::extract(fit, "foi", inc_warmup = TRUE, permuted=FALSE)[[1]]
# generates central estimations
foi_cent_est <- data.frame(
year = exposure_years,
lower = apply(foi, 2, function(x) quantile(x, lower_quantile)),

upper = apply(foi, 2, function(x) quantile(x, upper_quantile)),

medianv = apply(foi, 2, function(x) quantile(x, medianv_quantile))
)


# generates a sample of iterations
if (n_iters >= 2000) {
foi_post_s <- dplyr::sample_n(as.data.frame(foi), size = 1000)
colnames(foi_post_s) <- exposure_years
} else {
foi_post_s <- as.data.frame(foi)
colnames(foi_post_s) <- exposure_years
}

if (class(seromodel_fit@sim$samples) != "NULL") {
seromodel_object <- list(
fit = fit,
seromodel_fit = seromodel_fit,
serodata = serodata,
stan_data = stan_data,
exposure_years = exposure_years,
Expand All @@ -225,9 +193,8 @@ fit_seromodel <- function(serodata,
seromodel_object$model_summary <-
extract_seromodel_summary(seromodel_object)
} else {
loo_fit <- c(-1e10, 0)
seromodel_object <- list(
fit = "no model",
seromodel_fit = "no model",
serodata = serodata,
stan_data = stan_data,
exposure_years = exposure_years,
Expand Down Expand Up @@ -287,6 +254,46 @@ get_exposure_matrix <- function(serodata) {
return(exposure_output)
}

#' Function that generates the central estimates for the fitted forced FoI
#'
#' @param seromodel_object Object containing the results of fitting a model by means of \link{run_seromodel}.
#' generated by means of \link{get_exposure_ages}.
#' @return \code{foi_central_estimates}. Central estimates for the fitted forced FoI
#' @examples
#' \dontrun{
#' data(chagas2012)
#' serodata <- prepare_serodata(chagas2012)
#' seromodel_object <- fit_seromodel(serodata = serodata,
#' foi_model = "constant")
#' foi_central_estimates <- get_foi_central_estimates(seromodel_object)
#' }
#'
#' @export
get_foi_central_estimates <- function(seromodel_object) {

if (seromodel_object$seromodel_fit@model_name == "tv_normal_log") {
lower_quantile = 0.1
upper_quantile = 0.9
medianv_quantile = 0.5
}
else {
lower_quantile = 0.05
upper_quantile = 0.95
medianv_quantile = 0.5
}
# extracts foi from stan fit
foi <- rstan::extract(seromodel_object$seromodel_fit, "foi", inc_warmup = FALSE)[[1]]
# generates central estimations
foi_central_estimates <- data.frame(
year = seromodel_object$exposure_years,
lower = apply(foi, 2, function(x) quantile(x, lower_quantile)),

upper = apply(foi, 2, function(x) quantile(x, upper_quantile)),

medianv = apply(foi, 2, function(x) quantile(x, medianv_quantile))
)
return(foi_central_estimates)
}

#' Method to extact a summary of the specified serological model object
#'
Expand Down Expand Up @@ -370,7 +377,7 @@ extract_seromodel_summary <- function(seromodel_object) {
#' serodata <- prepare_serodata(chagas2012)
#' seromodel_object <- run_seromodel(serodata = serodata,
#' foi_model = "constant")
#' foi <- rstan::extract(seromodel_object$fit, "foi")[[1]]
#' foi <- rstan::extract(seromodel_object$seromodel_fit, "foi")[[1]]
#' get_prev_expanded <- function(foi, serodata)
#' }
#' @export
Expand Down
22 changes: 11 additions & 11 deletions R/visualisation.R
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,10 @@ plot_seroprev <- function(serodata,
plot_seroprev_fitted <- function(seromodel_object,
size_text = 6) {

if (is.character(seromodel_object$fit) == FALSE) {
if (class(seromodel_object$fit@sim$samples) != "NULL" ) {
if (is.character(seromodel_object$seromodel_fit) == FALSE) {
if (class(seromodel_object$seromodel_fit@sim$samples) != "NULL" ) {

foi <- rstan::extract(seromodel_object$fit, "foi", inc_warmup = FALSE)[[1]]
foi <- rstan::extract(seromodel_object$seromodel_fit, "foi", inc_warmup = FALSE)[[1]]
prev_expanded <- get_prev_expanded(foi, serodata = seromodel_object$serodata, bin_data = TRUE)
prev_plot <-
ggplot2::ggplot(prev_expanded) +
Expand Down Expand Up @@ -148,14 +148,14 @@ plot_foi <- function(seromodel_object,
max_lambda = NA,
size_text = 25,
foi_sim = NULL) {
if (is.character(seromodel_object$fit) == FALSE) {
if (class(seromodel_object$fit@sim$samples) != "NULL") {
foi <- rstan::extract(seromodel_object$fit,
if (is.character(seromodel_object$seromodel_fit) == FALSE) {
if (class(seromodel_object$seromodel_fit@sim$samples) != "NULL") {
foi <- rstan::extract(seromodel_object$seromodel_fit,
"foi",
inc_warmup = FALSE)[[1]]

#-------- This bit is to get the actual length of the foi data
foi_data <- seromodel_object$foi_cent_est
foi_data <- get_foi_central_estimates(seromodel_object = seromodel_object)

#--------
foi_data$medianv[1] <- NA
Expand Down Expand Up @@ -243,8 +243,8 @@ plot_foi <- function(seromodel_object,
#' @export
plot_rhats <- function(seromodel_object,
size_text = 25) {
if (is.character(seromodel_object$fit) == FALSE) {
if (class(seromodel_object$fit@sim$samples) != "NULL") {
if (is.character(seromodel_object$seromodel_fit) == FALSE) {
if (class(seromodel_object$seromodel_fit@sim$samples) != "NULL") {
rhats <- get_table_rhats(seromodel_object)

rhats_plot <-
Expand Down Expand Up @@ -310,8 +310,8 @@ plot_seromodel <- function(seromodel_object,
max_lambda = NA,
size_text = 25,
foi_sim = NULL) {
if (is.character(seromodel_object$fit) == FALSE) {
if (class(seromodel_object$fit@sim$samples) != "NULL") {
if (is.character(seromodel_object$seromodel_fit) == FALSE) {
if (class(seromodel_object$seromodel_fit@sim$samples) != "NULL") {
prev_plot <- plot_seroprev_fitted(seromodel_object = seromodel_object,
size_text = size_text)

Expand Down

0 comments on commit a02b92b

Please sign in to comment.