Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Simplify fit_seromodel() output and related refactorizations #108

Closed
wants to merge 8 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ export(extract_seromodel_summary)
export(fit_seromodel)
export(get_exposure_ages)
export(get_exposure_matrix)
export(get_foi_central_estimates)
export(get_prev_expanded)
export(get_table_rhats)
export(plot_foi)
Expand Down
2 changes: 1 addition & 1 deletion R/model_comparison.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
#' }
#' @export
get_table_rhats <- function(seromodel_object) {
rhats <- bayesplot::rhat(seromodel_object$fit, "foi")
rhats <- bayesplot::rhat(seromodel_object$seromodel_fit, "foi")

if (any(is.nan(rhats))) {
rhats[which(is.nan(rhats))] <- 0
Expand Down
137 changes: 62 additions & 75 deletions R/modelling.R
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,8 @@ run_seromodel <- function(serodata,
foi_model,
" finished running ------"))
if (print_summary){
print(t(seromodel_object$model_summary))
model_summary <- extract_seromodel_summary(seromodel_object = seromodel_object)
print(t(model_summary))
}
return(seromodel_object)
}
Expand Down Expand Up @@ -146,27 +147,18 @@ fit_seromodel <- function(serodata,
)

n_warmup <- floor(n_iters / 2)

if (foi_model == "tv_normal_log") {
f_init <- function() {
list(log_foi = rep(-3, length(exposure_ages)))
}
lower_quantile = 0.1
upper_quantile = 0.9
medianv_quantile = 0.5
}

}
else {
f_init <- function() {
list(foi = rep(0.01, length(exposure_ages)))
}
lower_quantile = 0.05
upper_quantile = 0.95
medianv_quantile = 0.5

}

fit <- rstan::sampling(
seromodel_fit <- rstan::sampling(
model,
data = stan_data,
iter = n_iters,
Expand All @@ -182,64 +174,21 @@ fit_seromodel <- function(serodata,
chain_id = 0 # https://github.com/stan-dev/rstan/issues/761#issuecomment-647029649
)

if (class(fit@sim$samples) != "NULL") {
loo_fit <- loo::loo(fit, save_psis = TRUE, "logLikelihood")
foi <- rstan::extract(fit, "foi", inc_warmup = FALSE)[[1]]
# foi <- rstan::extract(fit, "foi", inc_warmup = TRUE, permuted=FALSE)[[1]]
# generates central estimations
foi_cent_est <- data.frame(
year = exposure_years,
lower = apply(foi, 2, function(x) quantile(x, lower_quantile)),

upper = apply(foi, 2, function(x) quantile(x, upper_quantile)),

medianv = apply(foi, 2, function(x) quantile(x, medianv_quantile))
)


# generates a sample of iterations
if (n_iters >= 2000) {
foi_post_s <- dplyr::sample_n(as.data.frame(foi), size = 1000)
colnames(foi_post_s) <- exposure_years
} else {
foi_post_s <- as.data.frame(foi)
colnames(foi_post_s) <- exposure_years
}

if (class(seromodel_fit@sim$samples) != "NULL") {
seromodel_object <- list(
fit = fit,
seromodel_fit = seromodel_fit,
serodata = serodata,
stan_data = stan_data,
exposure_years = exposure_years,
exposure_ages = exposure_ages,
n_iters = n_iters,
n_thin = n_thin,
n_warmup = n_warmup,
foi_model = foi_model,
delta = delta,
m_treed = m_treed,
loo_fit = loo_fit,
foi_cent_est = foi_cent_est,
foi_post_s = foi_post_s
exposure_ages = exposure_ages
)
seromodel_object$model_summary <-
extract_seromodel_summary(seromodel_object)
} else {
loo_fit <- c(-1e10, 0)
seromodel_object <- list(
fit = "no model",
seromodel_fit = "no model",
serodata = serodata,
stan_data = stan_data,
exposure_years = exposure_years,
exposure_ages = exposure_ages,
n_iters = n_iters,
n_thin = n_thin,
n_warmup = n_warmup,
model = foi_model,
delta = delta,
m_treed = m_treed,
loo_fit = loo_fit,
model_summary = NA
exposure_ages = exposure_ages
)
}

Expand Down Expand Up @@ -287,6 +236,46 @@ get_exposure_matrix <- function(serodata) {
return(exposure_output)
}

#' Function that generates the central estimates for the fitted forced FoI
#'
#' @param seromodel_object Object containing the results of fitting a model by means of \link{run_seromodel}.
#' generated by means of \link{get_exposure_ages}.
#' @return \code{foi_central_estimates}. Central estimates for the fitted forced FoI
#' @examples
#' \dontrun{
#' data(chagas2012)
#' serodata <- prepare_serodata(chagas2012)
#' seromodel_object <- fit_seromodel(serodata = serodata,
#' foi_model = "constant")
#' foi_central_estimates <- get_foi_central_estimates(seromodel_object)
#' }
#'
#' @export
get_foi_central_estimates <- function(seromodel_object) {

if (seromodel_object$seromodel_fit@model_name == "tv_normal_log") {
lower_quantile = 0.1
upper_quantile = 0.9
medianv_quantile = 0.5
}
else {
lower_quantile = 0.05
upper_quantile = 0.95
medianv_quantile = 0.5
}
# extracts foi from stan fit
foi <- rstan::extract(seromodel_object$seromodel_fit, "foi", inc_warmup = FALSE)[[1]]
# generates central estimations
foi_central_estimates <- data.frame(
year = seromodel_object$exposure_years,
lower = apply(foi, 2, function(x) quantile(x, lower_quantile)),

upper = apply(foi, 2, function(x) quantile(x, upper_quantile)),

medianv = apply(foi, 2, function(x) quantile(x, medianv_quantile))
)
return(foi_central_estimates)
}

#' Method to extact a summary of the specified serological model object
#'
Expand Down Expand Up @@ -321,26 +310,24 @@ get_exposure_matrix <- function(serodata) {
#' }
#' @export
extract_seromodel_summary <- function(seromodel_object) {
foi_model <- seromodel_object$foi_model
serodata <- seromodel_object$serodata
#------- Loo estimates

loo_fit <- seromodel_object$loo_fit
loo_fit <- loo::loo(seromodel_object$seromodel_fit, save_psis = TRUE, "logLikelihood")
if (sum(is.na(loo_fit)) < 1) {
lll <- as.numeric((round(loo_fit$estimates[1, ], 2)))
} else {
lll <- c(-1e10, 0)
}
#-------
model_summary <- data.frame(
foi_model = foi_model,
dataset = serodata$survey[1],
country = serodata$country[1],
year = serodata$tsur[1],
test = serodata$test[1],
antibody = serodata$antibody[1],
n_sample = sum(serodata$total),
n_agec = length(serodata$age_mean_f),
n_iter = seromodel_object$n_iters,
foi_model = seromodel_object$seromodel_fit@model_name,
dataset = unique(seromodel_object$serodata$survey),
country = unique(seromodel_object$serodata$country),
year = unique(seromodel_object$serodata$tsur),
test = unique(seromodel_object$serodata$test),
antibody = unique(seromodel_object$serodata$antibody),
n_sample = sum(seromodel_object$serodata$total),
n_agec = length(seromodel_object$serodata$age_mean_f),
n_iter = seromodel_object$seromodel_fit@sim$iter,
elpd = lll[1],
se = lll[2],
converged = NA
Expand All @@ -360,7 +347,7 @@ extract_seromodel_summary <- function(seromodel_object) {
#'
#' This function computes the corresponding binomial confidence intervals for the obtained prevalence based on a fitting
#' of the Force-of-Infection \code{foi} for plotting an analysis purposes.
#' @param foi Object containing the information of the force of infection. It is obtained from \code{rstan::extract(seromodel_object$fit, "foi", inc_warmup = FALSE)[[1]]}.
#' @param foi Object containing the information of the force of infection. It is obtained from \code{rstan::extract(seromodel_object$seromodel, "foi", inc_warmup = FALSE)[[1]]}.
#' @param serodata A data frame containing the data from a seroprevalence survey. For further details refer to \link{run_seromodel}.
#' @param bin_data TBD
#' @return \code{prev_final}. The expanded prevalence data. This is used for plotting purposes in the \code{visualization} module.
Expand All @@ -370,7 +357,7 @@ extract_seromodel_summary <- function(seromodel_object) {
#' serodata <- prepare_serodata(chagas2012)
#' seromodel_object <- run_seromodel(serodata = serodata,
#' foi_model = "constant")
#' foi <- rstan::extract(seromodel_object$fit, "foi")[[1]]
#' foi <- rstan::extract(seromodel_object$seromodel_fit, "foi")[[1]]
#' get_prev_expanded <- function(foi, serodata)
#' }
#' @export
Expand Down
26 changes: 13 additions & 13 deletions R/visualisation.R
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,10 @@ plot_seroprev <- function(serodata,
plot_seroprev_fitted <- function(seromodel_object,
size_text = 6) {

if (is.character(seromodel_object$fit) == FALSE) {
if (class(seromodel_object$fit@sim$samples) != "NULL" ) {
if (is.character(seromodel_object$seromodel_fit) == FALSE) {
if (class(seromodel_object$seromodel_fit@sim$samples) != "NULL" ) {

foi <- rstan::extract(seromodel_object$fit, "foi", inc_warmup = FALSE)[[1]]
foi <- rstan::extract(seromodel_object$seromodel_fit, "foi", inc_warmup = FALSE)[[1]]
prev_expanded <- get_prev_expanded(foi, serodata = seromodel_object$serodata, bin_data = TRUE)
prev_plot <-
ggplot2::ggplot(prev_expanded) +
Expand Down Expand Up @@ -148,14 +148,14 @@ plot_foi <- function(seromodel_object,
max_lambda = NA,
size_text = 25,
foi_sim = NULL) {
if (is.character(seromodel_object$fit) == FALSE) {
if (class(seromodel_object$fit@sim$samples) != "NULL") {
foi <- rstan::extract(seromodel_object$fit,
if (is.character(seromodel_object$seromodel_fit) == FALSE) {
if (class(seromodel_object$seromodel_fit@sim$samples) != "NULL") {
foi <- rstan::extract(seromodel_object$seromodel_fit,
"foi",
inc_warmup = FALSE)[[1]]

#-------- This bit is to get the actual length of the foi data
foi_data <- seromodel_object$foi_cent_est
foi_data <- get_foi_central_estimates(seromodel_object = seromodel_object)

#--------
foi_data$medianv[1] <- NA
Expand Down Expand Up @@ -243,8 +243,8 @@ plot_foi <- function(seromodel_object,
#' @export
plot_rhats <- function(seromodel_object,
size_text = 25) {
if (is.character(seromodel_object$fit) == FALSE) {
if (class(seromodel_object$fit@sim$samples) != "NULL") {
if (is.character(seromodel_object$seromodel_fit) == FALSE) {
if (class(seromodel_object$seromodel_fit@sim$samples) != "NULL") {
rhats <- get_table_rhats(seromodel_object)

rhats_plot <-
Expand Down Expand Up @@ -310,8 +310,8 @@ plot_seromodel <- function(seromodel_object,
max_lambda = NA,
size_text = 25,
foi_sim = NULL) {
if (is.character(seromodel_object$fit) == FALSE) {
if (class(seromodel_object$fit@sim$samples) != "NULL") {
if (is.character(seromodel_object$seromodel_fit) == FALSE) {
if (class(seromodel_object$seromodel_fit@sim$samples) != "NULL") {
prev_plot <- plot_seroprev_fitted(seromodel_object = seromodel_object,
size_text = size_text)

Expand All @@ -324,9 +324,9 @@ plot_seromodel <- function(seromodel_object,

rhats_plot <- plot_rhats(seromodel_object = seromodel_object,
size_text = size_text)

model_summary <- extract_seromodel_summary(seromodel_object = seromodel_object)
summary_table <- t(
dplyr::select(seromodel_object$model_summary,
dplyr::select(model_summary,
c('foi_model', 'dataset', 'elpd', 'se', 'converged')))
summary_plot <-
plot_info_table(summary_table, size_text = size_text)
Expand Down
28 changes: 28 additions & 0 deletions man/get_foi_central_estimates.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading