diff --git a/R/create.R b/R/create.R index 376423ffb..8be6e005e 100644 --- a/R/create.R +++ b/R/create.R @@ -654,8 +654,7 @@ create_stan_delays <- function(..., weight = 1) { ## number of different non-empty types type_n <- unlist(purrr::transpose(dot_args)$n) ## assign ID values to each type - ids <- rep(0L, length(type_n)) - ids[type_n > 0] <- seq_len(sum(type_n > 0)) + ids <- seq_along(type_n) names(ids) <- paste(names(type_n), "id", sep = "_") ## start consructing stan object diff --git a/R/dist.R b/R/dist.R index f691efed8..b74504a4f 100644 --- a/R/dist.R +++ b/R/dist.R @@ -978,13 +978,7 @@ dist_spec <- function(mean, sd = 0, mean_sd = 0, sd_sd = 0, ) if (length(pmf) == 0) { if (missing(mean)) { ## empty - ret <- c(ret, list( - n = 0, - n_p = 0, - n_np = 0, - np_pmf = numeric(0), - fixed = integer(0) - )) + pmf <- 1 } else { ## parametric fixed if (sd == 0) { ## delta pmf <- c(rep(0, mean), 1) @@ -1013,15 +1007,13 @@ dist_spec <- function(mean, sd = 0, mean_sd = 0, sd_sd = 0, } pmf <- pmf / sum(pmf) } - if (length(pmf) > 0) { - ret <- c(ret, list( - n = 1, - n_p = 0, - n_np = 1, - np_pmf = pmf, - fixed = 1L - )) - } + ret <- c(ret, list( + n = 1, + n_p = 0, + n_np = 1, + np_pmf = pmf, + fixed = 1L + )) } else { ret <- list( mean_mean = mean, diff --git a/R/estimate_infections.R b/R/estimate_infections.R index 73a29cf99..b6dcc7c30 100644 --- a/R/estimate_infections.R +++ b/R/estimate_infections.R @@ -65,6 +65,7 @@ #' @importFrom lubridate days #' @importFrom purrr transpose #' @importFrom futile.logger flog.threshold flog.warn flog.debug +#' @importFrom rstan extract #' @examples #' \donttest{ #' # set number of cores to use @@ -234,7 +235,8 @@ estimate_infections <- function(reported_cases, fit <- fit_model_with_vb(args, id = id) } # Extract parameters of interest from the fit - out <- extract_parameter_samples(fit, data, + samples <- extract(fit) + out <- extract_parameter_samples(samples, data, reported_inf_dates = reported_cases$date, reported_dates = 
reported_cases$date[-(1:data$seeding_time)] ) diff --git a/R/estimate_secondary.R b/R/estimate_secondary.R index c1824baf1..11349d0fe 100644 --- a/R/estimate_secondary.R +++ b/R/estimate_secondary.R @@ -628,6 +628,7 @@ forecast_secondary <- function(estimate, ) # extract data from stanfit data <- estimate$data + data$primary <- NULL # combined primary from data and input primary primary_fit <- estimate$predictions[, @@ -643,19 +644,20 @@ forecast_secondary <- function(estimate, data.table::setorderv(primary_fit, c("sample", "date")) # update data with primary samples and day of week - data$primary <- t( + data$primary_samples <- t( matrix(primary_fit$value, ncol = length(unique(primary_fit$sample))) ) data$day_of_week <- add_day_of_week( unique(primary_fit$date), data$week_effect ) - data$n <- nrow(data$primary) - data$t <- ncol(data$primary) + data$n <- nrow(data$primary_samples) + data$t <- ncol(data$primary_samples) data$h <- nrow(primary[sample == min(sample)]) # extract samples for posterior of estimates posterior_samples <- sample(seq_len(data$n), data$n, replace = TRUE) # nolint draws <- purrr::map(draws, ~ as.matrix(.[posterior_samples, ])) + names(draws) <- paste0(names(draws), "_samples") # combine with data data <- c(data, draws) @@ -666,8 +668,10 @@ forecast_secondary <- function(estimate, # allocate empty parameters data <- allocate_empty( - data, c("frac_obs", "delay_mean", "delay_sd", "rep_phi"), - n = data$n + data, c( + "frac_obs_samples", "delay_mean_samples", "delay_sd_samples", + "rep_phi_samples" + ), n = data$n ) data$all_dates <- as.integer(all_dates) ## simulate diff --git a/R/extract.R b/R/extract.R index 88df1b594..54ebb278e 100644 --- a/R/extract.R +++ b/R/extract.R @@ -60,7 +60,8 @@ extract_static_parameter <- function(param, samples) { #' Extracts a custom set of parameters from a stan object and adds #' stratification and dates where appropriate. #' -#' @param stan_fit A fit Stan model as returned by `rstan:sampling`. 
+#' @param samples A list of arrays containing samples extracted from a stan +#' model using \code{\link[rstan]{extract}} #' #' @param data A list of the data supplied to the `rstan::sampling` call. #' @@ -77,14 +78,10 @@ extract_static_parameter <- function(param, samples) { #' #' @return A list of dataframes each containing the posterior of a parameter #' @author Sam Abbott -#' @importFrom rstan extract #' @importFrom data.table data.table -extract_parameter_samples <- function(stan_fit, data, reported_dates, +extract_parameter_samples <- function(samples, data, reported_dates, reported_inf_dates, drop_length_1 = FALSE, merge = FALSE) { - # extract sample from stan object - samples <- rstan::extract(stan_fit) - ## drop initial length 1 dimensions if requested if (drop_length_1) { samples <- lapply(samples, function(x) { diff --git a/R/simulate_infections.R b/R/simulate_infections.R index ada43c60c..b6bfaa837 100644 --- a/R/simulate_infections.R +++ b/R/simulate_infections.R @@ -34,6 +34,7 @@ #' @importFrom progressr with_progress progressor #' @importFrom data.table rbindlist as.data.table #' @importFrom lubridate days +#' @importFrom rstan extract #' @return A list of output as returned by [estimate_infections()] but based on #' results from the specified scenario rather than fitting. 
#' @export @@ -186,13 +187,17 @@ simulate_infections <- function(estimates, shift, dates, nstart, nend) { # extract batch samples from draws draws <- map(draws, ~ as.matrix(.[nstart:nend, ])) + names(draws) <- paste0(names(draws), "_samples") ## prepare data for stan command - data <- c(list(n = dim(draws$R)[1]), draws, estimates$args) + data <- c(list(n = dim(draws$R_samples)[1]), draws, estimates$args) ## allocate empty parameters data <- allocate_empty( - data, c("frac_obs", "delay_mean", "delay_sd", "rep_phi"), + data, c( + "frac_obs_samples", "delay_mean_samples", "delay_sd_samples", + "rep_phi_samples" + ), n = data$n ) @@ -203,8 +208,12 @@ simulate_infections <- function(estimates, algorithm = "Fixed_param", refresh = 0 ) + ## extract sample from stan object + samples <- extract(sims) + names(samples) <- sub("^sim_", "", names(samples)) + names(data) <- sub("_samples$", "", names(data)) - out <- extract_parameter_samples(sims, data, + out <- extract_parameter_samples(samples, data, reported_inf_dates = dates, reported_dates = dates[-(1:shift)], drop_length_1 = TRUE, merge = TRUE @@ -251,6 +260,8 @@ simulate_infections <- function(estimates, out <- transpose(out) out <- map(out, rbindlist) + names(out) <- sub("^sim_", "", names(out)) + ## format output format_out <- format_fit( posterior_samples = out, diff --git a/inst/stan/chunks/R_to_growth.stan b/inst/stan/chunks/R_to_growth.stan new file mode 100644 index 000000000..c3d39f38b --- /dev/null +++ b/inst/stan/chunks/R_to_growth.stan @@ -0,0 +1,3 @@ +real gt_mean = rev_pmf_mean(gt_rev_pmf, 1); +real gt_var = rev_pmf_var(gt_rev_pmf, 1, gt_mean); +r = R_to_growth(R, gt_mean, gt_var); diff --git a/inst/stan/chunks/calculate_secondary.stan b/inst/stan/chunks/calculate_secondary.stan new file mode 100644 index 000000000..113bd6e86 --- /dev/null +++ b/inst/stan/chunks/calculate_secondary.stan @@ -0,0 +1,4 @@ +reports = calculate_secondary( + primary, obs, frac_obs, delay_rev_pmf, cumulative, + historic, 
primary_hist_additive, current, primary_current_additive, predict +); diff --git a/inst/stan/chunks/convolve_to_report.stan b/inst/stan/chunks/convolve_to_report.stan new file mode 100644 index 000000000..24f1db430 --- /dev/null +++ b/inst/stan/chunks/convolve_to_report.stan @@ -0,0 +1 @@ +reports = convolve_to_report(infections, delay_rev_pmf, seeding_time); diff --git a/inst/stan/chunks/day_of_week_effect.stan b/inst/stan/chunks/day_of_week_effect.stan new file mode 100644 index 000000000..9893460fe --- /dev/null +++ b/inst/stan/chunks/day_of_week_effect.stan @@ -0,0 +1,3 @@ + if (week_effect > 1) { + reports = day_of_week_effect(reports, day_of_week, day_of_week_simplex); + } diff --git a/inst/stan/chunks/delay_rev_pmf.stan b/inst/stan/chunks/delay_rev_pmf.stan new file mode 100644 index 000000000..f0bba3e0f --- /dev/null +++ b/inst/stan/chunks/delay_rev_pmf.stan @@ -0,0 +1,6 @@ +vector[delay_type_max[delay_id] + 1] delay_rev_pmf = get_delay_rev_pmf( + delay_id, delay_type_max[delay_id] + 1, delay_types_p, delay_types_id, + delay_types_groups, delay_max, delay_np_pmf, + delay_np_pmf_groups, delay_mean, delay_sd, delay_dist, + 0, 1, 0 +); diff --git a/inst/stan/chunks/delay_type_max.stan b/inst/stan/chunks/delay_type_max.stan new file mode 100644 index 000000000..7e9ea0696 --- /dev/null +++ b/inst/stan/chunks/delay_type_max.stan @@ -0,0 +1,12 @@ +array[delay_types] int delay_type_max; +for (i in 1:delay_types) { + delay_type_max[i] = 0; + for (j in delay_types_groups[i]:(delay_types_groups[i + 1] - 1)) { + if (delay_types_p[j]) { // parametric + delay_type_max[i] += delay_max[delay_types_id[j]]; + } else { // nonparametric + delay_type_max[i] += delay_np_pmf_groups[delay_types_id[j] + 1] - + delay_np_pmf_groups[delay_types_id[j]] - 1; + } + } +} diff --git a/inst/stan/chunks/delays_lp.stan b/inst/stan/chunks/delays_lp.stan new file mode 100644 index 000000000..4a899d1ef --- /dev/null +++ b/inst/stan/chunks/delays_lp.stan @@ -0,0 +1,4 @@ +delays_lp( + delay_mean, 
delay_mean_mean, delay_mean_sd, delay_sd, delay_sd_mean, + delay_sd_sd, delay_dist, delay_weight +); diff --git a/inst/stan/chunks/generate_infections.stan b/inst/stan/chunks/generate_infections.stan new file mode 100644 index 000000000..340c5bb1a --- /dev/null +++ b/inst/stan/chunks/generate_infections.stan @@ -0,0 +1,4 @@ +infections = generate_infections( + R, seeding_time, gt_rev_pmf, initial_infections, initial_growth, pop, + future_time +); diff --git a/inst/stan/chunks/gt_rev_pmf.stan b/inst/stan/chunks/gt_rev_pmf.stan new file mode 100644 index 000000000..dbc8336b6 --- /dev/null +++ b/inst/stan/chunks/gt_rev_pmf.stan @@ -0,0 +1,6 @@ +vector[delay_type_max[gt_id] + 1] gt_rev_pmf = get_delay_rev_pmf( + gt_id, delay_type_max[gt_id] + 1, delay_types_p, delay_types_id, + delay_types_groups, delay_max, delay_np_pmf, + delay_np_pmf_groups, delay_mean, delay_sd, delay_dist, + 1, 1, 0 +); diff --git a/inst/stan/chunks/impute_reports.stan b/inst/stan/chunks/impute_reports.stan new file mode 100644 index 000000000..2f625f27c --- /dev/null +++ b/inst/stan/chunks/impute_reports.stan @@ -0,0 +1 @@ +imputed_reports = report_rng(reports, rep_phi, model_type); diff --git a/inst/stan/chunks/obs_scale_lp.stan b/inst/stan/chunks/obs_scale_lp.stan new file mode 100644 index 000000000..fb6cc77bf --- /dev/null +++ b/inst/stan/chunks/obs_scale_lp.stan @@ -0,0 +1,3 @@ +if (obs_scale) { + frac_obs[1] ~ normal(obs_scale_mean, obs_scale_sd) T[0, 1]; +} diff --git a/inst/stan/chunks/report_log_lik.stan b/inst/stan/chunks/report_log_lik.stan new file mode 100644 index 000000000..73013ddf6 --- /dev/null +++ b/inst/stan/chunks/report_log_lik.stan @@ -0,0 +1,3 @@ +log_lik = report_log_lik( + cases, obs_reports, rep_phi, model_type, obs_weight +); diff --git a/inst/stan/chunks/report_lp.stan b/inst/stan/chunks/report_lp.stan new file mode 100644 index 000000000..3255284e8 --- /dev/null +++ b/inst/stan/chunks/report_lp.stan @@ -0,0 +1,3 @@ +report_lp( + cases, obs_reports, rep_phi, phi_mean, 
phi_sd, model_type, obs_weight +); diff --git a/inst/stan/chunks/scale_obs.stan b/inst/stan/chunks/scale_obs.stan new file mode 100644 index 000000000..380418c7b --- /dev/null +++ b/inst/stan/chunks/scale_obs.stan @@ -0,0 +1,3 @@ +if (obs_scale) { + reports = scale_obs(reports, frac_obs[1]); +} diff --git a/inst/stan/chunks/sim_vars.stan b/inst/stan/chunks/sim_vars.stan new file mode 100644 index 000000000..2c424d8d4 --- /dev/null +++ b/inst/stan/chunks/sim_vars.stan @@ -0,0 +1,7 @@ +array[delay_n_p] real delay_mean = delay_mean_samples[i]; +array[delay_n_p] real delay_sd = delay_sd_samples[i]; +array[obs_scale] real frac_obs = frac_obs_samples[i]; +array[model_type] real rep_phi = rep_phi_samples[i]; +vector[week_effect] day_of_week_simplex = to_vector( + day_of_week_simplex_samples[i] +); diff --git a/inst/stan/chunks/trunc_rev_cmf.stan b/inst/stan/chunks/trunc_rev_cmf.stan new file mode 100644 index 000000000..7415d4ba7 --- /dev/null +++ b/inst/stan/chunks/trunc_rev_cmf.stan @@ -0,0 +1,6 @@ +vector[delay_type_max[trunc_id] + 1] trunc_rev_cmf = get_delay_rev_pmf( + trunc_id, delay_type_max[trunc_id] + 1, delay_types_p, delay_types_id, + delay_types_groups, delay_max, delay_np_pmf, + delay_np_pmf_groups, delay_mean, delay_sd, delay_dist, + 0, 1, 1 +); diff --git a/inst/stan/data/simulation_delays.stan b/inst/stan/data/simulation_delays.stan index 0ceeedcaa..9446e84b9 100644 --- a/inst/stan/data/simulation_delays.stan +++ b/inst/stan/data/simulation_delays.stan @@ -1,8 +1,8 @@ int delay_n; // number of delay distribution distributions int delay_n_p; // number of parametric delay distributions int delay_n_np; // number of nonparametric delay distributions - array[n, delay_n_p] real delay_mean; // prior mean of mean delay distribution - array[n, delay_n_p] real delay_sd; // prior sd of sd of delay distribution + array[n, delay_n_p] real delay_mean_samples; // prior mean of mean delay distribution + array[n, delay_n_p] real delay_sd_samples; // prior sd of sd of delay 
distribution array[delay_n_p] int delay_max; // maximum delay distribution array[delay_n_p] int delay_dist; // 0 = lognormal; 1 = gamma int delay_np_pmf_length; // number of nonparametric pmf elements diff --git a/inst/stan/data/simulation_observation_model.stan b/inst/stan/data/simulation_observation_model.stan index c8cab6b35..5e4e92576 100644 --- a/inst/stan/data/simulation_observation_model.stan +++ b/inst/stan/data/simulation_observation_model.stan @@ -1,8 +1,8 @@ array[t - seeding_time] int day_of_week; // day of the week indicator (1 - 7) int week_effect; // should a day of the week effect be estimated - array[n, week_effect] real day_of_week_simplex; + array[n, week_effect] real day_of_week_simplex_samples; int obs_scale; - array[n, obs_scale] real frac_obs; + array[n, obs_scale] real frac_obs_samples; int model_type; - array[n, model_type] real rep_phi; // overdispersion of the reporting process + array[n, model_type] real rep_phi_samples; // overdispersion of the reporting process int trunc_id; // id of truncation diff --git a/inst/stan/data/simulation_rt.stan b/inst/stan/data/simulation_rt.stan index a97300451..ab1e39efc 100644 --- a/inst/stan/data/simulation_rt.stan +++ b/inst/stan/data/simulation_rt.stan @@ -1,7 +1,7 @@ - array[seeding_time ? n : 0, 1] real initial_infections; // initial logged infections - array[seeding_time > 1 ? n : 0, 1] real initial_growth; //initial growth +array[seeding_time ? n : 0, 1] real initial_infections_samples; // initial logged infections +array[seeding_time > 1 ? 
n : 0, 1] real initial_growth_samples; //initial growth - matrix[n, t - seeding_time] R; // reproduction number - int pop; // susceptible population +array[n, t - seeding_time] real R_samples; // reproduction number +int pop; // susceptible population - int gt_id; // id of generation time +int gt_id; // id of generation time diff --git a/inst/stan/estimate_infections.stan b/inst/stan/estimate_infections.stan index a97bb57e5..1396f59ad 100644 --- a/inst/stan/estimate_infections.stan +++ b/inst/stan/estimate_infections.stan @@ -30,13 +30,12 @@ transformed data{ real r_logmean = log(r_mean^2 / sqrt(r_sd^2 + r_mean^2)); real r_logsd = sqrt(log(1 + (r_sd^2 / r_mean^2))); - array[delay_types] int delay_type_max = get_delay_type_max( - delay_types, delay_types_p, delay_types_id, - delay_types_groups, delay_max, delay_np_pmf_groups - ); +#include chunks/delay_type_max.stan } parameters{ +#include params/delays.stan +#include params/observation_model.stan // gaussian process array[fixed ? 0 : 1] real rho; // length scale of noise GP array[fixed ? 0 : 1] real alpha; // scale of of noise GP @@ -47,12 +46,6 @@ parameters{ array[estimate_r && seeding_time > 1 ? 1 : 0] real initial_growth; // seed growth rate array[bp_n > 0 ? 
1 : 0] real bp_sd; // standard deviation of breakpoint effect array[bp_n] real bp_effects; // Rt breakpoint effects - // observation model - array[delay_n_p] real delay_mean; // mean of delays - array[delay_n_p] real delay_sd; // sd of delays - simplex[week_effect] day_of_week_simplex;// day of week reporting effect - array[obs_scale] real frac_obs; // fraction of cases that are ultimately observed - array[model_type] real rep_phi; // overdispersion of the reporting process } transformed parameters { @@ -61,26 +54,17 @@ transformed parameters { vector[t] infections; // latent infections vector[ot_h] reports; // estimated reported cases vector[ot] obs_reports; // observed estimated reported cases - vector[estimate_r * (delay_type_max[gt_id] + 1)] gt_rev_pmf; // GP in noise - spectral densities if (!fixed) { noise = update_gp(PHI, M, L, alpha[1], rho[1], eta, gp_type); } // Estimate latent infections if (estimate_r) { - gt_rev_pmf = get_delay_rev_pmf( - gt_id, delay_type_max[gt_id] + 1, delay_types_p, delay_types_id, - delay_types_groups, delay_max, delay_np_pmf, - delay_np_pmf_groups, delay_mean, delay_sd, delay_dist, - 1, 1, 0 - ); +#include chunks/gt_rev_pmf.stan R = update_Rt( ot_h, log_R[estimate_r], noise, breakpoints, bp_effects, stationary ); - infections = generate_infections( - R, seeding_time, gt_rev_pmf, initial_infections, initial_growth, pop, - future_time - ); +#include chunks/generate_infections.stan } else { // via deconvolution infections = deconvolve_infections( @@ -88,37 +72,19 @@ transformed parameters { ); } // convolve from latent infections to mean of observations - if (delay_id) { - vector[delay_type_max[delay_id] + 1] delay_rev_pmf = get_delay_rev_pmf( - delay_id, delay_type_max[delay_id] + 1, delay_types_p, delay_types_id, - delay_types_groups, delay_max, delay_np_pmf, - delay_np_pmf_groups, delay_mean, delay_sd, delay_dist, - 0, 1, 0 - ); - reports = convolve_to_report(infections, delay_rev_pmf, seeding_time); - } else { - reports = 
infections[(seeding_time + 1):t]; + { +#include chunks/delay_rev_pmf.stan +#include chunks/convolve_to_report.stan } // weekly reporting effect - if (week_effect > 1) { - reports = day_of_week_effect(reports, day_of_week, day_of_week_simplex); - } +#include chunks/day_of_week_effect.stan // scaling of reported cases by fraction observed - if (obs_scale) { - reports = scale_obs(reports, frac_obs[1]); - } +#include chunks/scale_obs.stan // truncate near time cases to observed reports - if (trunc_id) { - vector[delay_type_max[trunc_id] + 1] trunc_rev_cmf = get_delay_rev_pmf( - trunc_id, delay_type_max[trunc_id] + 1, delay_types_p, delay_types_id, - delay_types_groups, delay_max, delay_np_pmf, - delay_np_pmf_groups, delay_mean, delay_sd, delay_dist, - 0, 1, 1 - ); + { +#include chunks/trunc_rev_cmf.stan obs_reports = truncate(reports[1:ot], trunc_rev_cmf, 0); - } else { - obs_reports = reports[1:ot]; - } + } } model { @@ -129,11 +95,7 @@ model { ); } // penalised priors for delay distributions - delays_lp( - delay_mean, delay_mean_mean, - delay_mean_sd, delay_sd, delay_sd_mean, delay_sd_sd, - delay_dist, delay_weight - ); +#include chunks/delays_lp.stan if (estimate_r) { // priors on Rt rt_lp( @@ -142,14 +104,10 @@ model { ); } // prior observation scaling - if (obs_scale) { - frac_obs[1] ~ normal(obs_scale_mean, obs_scale_sd) T[0, 1]; - } +#include chunks/obs_scale_lp.stan // observed reports from mean of reports (update likelihood) if (likelihood) { - report_lp( - cases, obs_reports, rep_phi, phi_mean, phi_sd, model_type, obs_weight - ); +#include chunks/report_lp.stan } } @@ -157,14 +115,11 @@ generated quantities { array[ot_h] int imputed_reports; vector[estimate_r > 0 ? 0: ot_h] gen_R; array[ot_h] real r; - real gt_mean; - real gt_var; vector[return_likelihood ? 
ot : 0] log_lik; if (estimate_r){ +#include chunks/gt_rev_pmf.stan // estimate growth from estimated Rt - gt_mean = rev_pmf_mean(gt_rev_pmf, 1); - gt_var = rev_pmf_var(gt_rev_pmf, 1, gt_mean); - r = R_to_growth(R, gt_mean, gt_var); +#include chunks/R_to_growth.stan } else { // sample generation time array[delay_n_p] real delay_mean_sample = @@ -177,8 +132,8 @@ generated quantities { delay_np_pmf_groups, delay_mean_sample, delay_sd_sample, delay_dist, 1, 1, 0 ); - gt_mean = rev_pmf_mean(sampled_gt_rev_pmf, 1); - gt_var = rev_pmf_var(sampled_gt_rev_pmf, 1, gt_mean); + real gt_mean = rev_pmf_mean(sampled_gt_rev_pmf, 1); + real gt_var = rev_pmf_var(sampled_gt_rev_pmf, 1, gt_mean); // calculate Rt using infections and generation time gen_R = calculate_Rt( infections, seeding_time, sampled_gt_rev_pmf, rt_half_window @@ -187,11 +142,9 @@ generated quantities { r = R_to_growth(gen_R, gt_mean, gt_var); } // simulate reported cases - imputed_reports = report_rng(reports, rep_phi, model_type); +#include chunks/impute_reports.stan // log likelihood of model if (return_likelihood) { - log_lik = report_log_lik( - cases, obs_reports, rep_phi, model_type, obs_weight - ); +#include chunks/report_log_lik.stan } } diff --git a/inst/stan/estimate_secondary.stan b/inst/stan/estimate_secondary.stan index e2209349b..5a3fa9dd3 100644 --- a/inst/stan/estimate_secondary.stan +++ b/inst/stan/estimate_secondary.stan @@ -17,74 +17,40 @@ data { } transformed data{ - array[delay_types] int delay_type_max = get_delay_type_max( - delay_types, delay_types_p, delay_types_id, - delay_types_groups, delay_max, delay_np_pmf_groups - ); + int predict = t; +#include chunks/delay_type_max.stan } parameters{ // observation model - array[delay_n_p] real delay_mean; - array[delay_n_p] real delay_sd; // sd of delays - simplex[week_effect] day_of_week_simplex; // day of week reporting effect - array[obs_scale] real frac_obs; // fraction of cases that are ultimately observed - array[model_type] real rep_phi; // 
overdispersion of the reporting process +#include params/delays.stan +#include params/observation_model.stan } transformed parameters { - vector[t] secondary; - // calculate secondary reports from primary + vector[t] reports; // secondary reports { - vector[delay_type_max[delay_id] + 1] delay_rev_pmf; - if (delay_id) { - delay_rev_pmf = get_delay_rev_pmf( - delay_id, delay_type_max[delay_id] + 1, delay_types_p, delay_types_id, - delay_types_groups, delay_max, delay_np_pmf, - delay_np_pmf_groups, delay_mean, delay_sd, delay_dist, - 0, 1, 0 - ); - } else { - delay_rev_pmf = to_vector({ 1 }); - } - secondary = calculate_secondary( - primary, obs, frac_obs, delay_rev_pmf, cumulative, historic, - primary_hist_additive, current, primary_current_additive, t - ); +#include chunks/delay_rev_pmf.stan +#include chunks/calculate_secondary.stan + // weekly reporting effect +#include chunks/day_of_week_effect.stan + // truncate near time cases to observed reports +#include chunks/trunc_rev_cmf.stan + reports = truncate(reports, trunc_rev_cmf, 0); } - - // weekly reporting effect - if (week_effect > 1) { - secondary = day_of_week_effect(secondary, day_of_week, day_of_week_simplex); - } - // truncate near time cases to observed reports - if (trunc_id) { - vector[delay_type_max[trunc_id]] trunc_rev_cmf = get_delay_rev_pmf( - trunc_id, delay_type_max[trunc_id] + 1, delay_types_p, delay_types_id, - delay_types_groups, delay_max, delay_np_pmf, - delay_np_pmf_groups, delay_mean, delay_sd, delay_dist, - 0, 1, 1 - ); - secondary = truncate(secondary, trunc_rev_cmf, 0); - } } model { // penalised priors for delay distributions - delays_lp( - delay_mean, delay_mean_mean, delay_mean_sd, delay_sd, delay_sd_mean, - delay_sd_sd, delay_dist, delay_weight - ); - +#include chunks/delays_lp.stan // prior primary report scaling - if (obs_scale) { - frac_obs[1] ~ normal(obs_scale_mean, obs_scale_sd) T[0, 1]; - } +#include chunks/obs_scale_lp.stan // observed secondary reports from mean of secondary 
reports (update likelihood) if (likelihood) { - report_lp(obs[(burn_in + 1):t], secondary[(burn_in + 1):t], - rep_phi, phi_mean, phi_sd, model_type, 1); + array[t - burn_in] int cases = obs[(burn_in + 1):t]; + vector[t - burn_in] obs_reports = reports[(burn_in + 1):t]; +#include chunks/report_lp.stan } } @@ -92,10 +58,11 @@ generated quantities { array[t - burn_in] int sim_secondary; vector[return_likelihood > 1 ? t - burn_in : 0] log_lik; // simulate secondary reports - sim_secondary = report_rng(secondary[(burn_in + 1):t], rep_phi, model_type); + sim_secondary = report_rng(reports[(burn_in + 1):t], rep_phi, model_type); // log likelihood of model if (return_likelihood) { - log_lik = report_log_lik(obs[(burn_in + 1):t], secondary[(burn_in + 1):t], - rep_phi, model_type, obs_weight); + array[t - burn_in] int cases = obs[(burn_in + 1):t]; + vector[t - burn_in] obs_reports = reports[(burn_in + 1):t]; +#include chunks/report_log_lik.stan } } diff --git a/inst/stan/estimate_truncation.stan b/inst/stan/estimate_truncation.stan index b5ffd673b..8fbb24b11 100644 --- a/inst/stan/estimate_truncation.stan +++ b/inst/stan/estimate_truncation.stan @@ -16,11 +16,7 @@ transformed data{ array[obs_sets] int end_t; array[obs_sets] int start_t; - array[delay_types] int delay_type_max; - delay_type_max = get_delay_type_max( - delay_types, delay_types_p, delay_types_id, - delay_types_groups, delay_max, delay_np_pmf_groups - ); +#include chunks/delay_type_max.stan for (i in 1:obs_sets) { end_t[i] = t - obs_dist[i]; @@ -28,8 +24,7 @@ transformed data{ } } parameters { - array[delay_n_p] real delay_mean; - array[delay_n_p] real delay_sd; // sd of delays +#include params/delays.stan real phi; real sigma; } @@ -38,12 +33,7 @@ transformed parameters{ matrix[delay_type_max[trunc_id] + 1, obs_sets - 1] trunc_obs = rep_matrix( 0, delay_type_max[trunc_id] + 1, obs_sets - 1 ); - vector[delay_type_max[trunc_id] + 1] trunc_rev_cmf = get_delay_rev_pmf( - trunc_id, delay_type_max[trunc_id] + 1, 
delay_types_p, delay_types_id, - delay_types_groups, delay_max, delay_np_pmf, - delay_np_pmf_groups, delay_mean, delay_sd, delay_dist, - 0, 1, 1 - ); +#include chunks/trunc_rev_cmf.stan { vector[t] last_obs; // reconstruct latest data without truncation @@ -59,11 +49,9 @@ transformed parameters{ } model { // priors for the log normal truncation distribution - delays_lp( - delay_mean, delay_mean_mean, delay_mean_sd, delay_sd, delay_sd_mean, - delay_sd_sd, delay_dist, delay_weight - ); - + +#include chunks/delays_lp.stan + phi ~ normal(0, 1) T[0,]; sigma ~ normal(0, 1) T[0,]; diff --git a/inst/stan/functions/delays.stan b/inst/stan/functions/delays.stan index 203594efa..92ad1c436 100644 --- a/inst/stan/functions/delays.stan +++ b/inst/stan/functions/delays.stan @@ -1,22 +1,3 @@ -array[] int get_delay_type_max( - int delay_types, array[] int delay_types_p, array[] int delay_types_id, - array[] int delay_types_groups, array[] int delay_max, array[] int delay_np_pmf_groups -) { - array[delay_types] int ret; - for (i in 1:delay_types) { - ret[i] = 0; - for (j in delay_types_groups[i]:(delay_types_groups[i + 1] - 1)) { - if (delay_types_p[j]) { // parametric - ret[i] += delay_max[delay_types_id[j]]; - } else { // nonparametric - ret[i] += delay_np_pmf_groups[delay_types_id[j] + 1] - - delay_np_pmf_groups[delay_types_id[j]] - 1; - } - } - } - return ret; -} - vector get_delay_rev_pmf( int delay_id, int len, array[] int delay_types_p, array[] int delay_types_id, array[] int delay_types_groups, array[] int delay_max, diff --git a/inst/stan/params/delays.stan b/inst/stan/params/delays.stan new file mode 100644 index 000000000..9685243c4 --- /dev/null +++ b/inst/stan/params/delays.stan @@ -0,0 +1,2 @@ +array[delay_n_p] real delay_mean; // mean of delays +array[delay_n_p] real delay_sd; // sd of delays diff --git a/inst/stan/params/observation_model.stan b/inst/stan/params/observation_model.stan new file mode 100644 index 000000000..2362d0b9a --- /dev/null +++ 
b/inst/stan/params/observation_model.stan @@ -0,0 +1,3 @@ +simplex[week_effect] day_of_week_simplex;// day of week reporting effect +array[obs_scale] real frac_obs; // fraction of cases that are ultimately observed +array[model_type] real rep_phi; // overdispersion of the reporting process diff --git a/inst/stan/simulate_infections.stan b/inst/stan/simulate_infections.stan index d8448d047..1f4cfab76 100644 --- a/inst/stan/simulate_infections.stan +++ b/inst/stan/simulate_infections.stan @@ -24,66 +24,45 @@ data { } transformed data { - array[delay_types] int delay_type_max = get_delay_type_max( - delay_types, delay_types_p, delay_types_id, - delay_types_groups, delay_max, delay_np_pmf_groups - ); +#include chunks/delay_type_max.stan } generated quantities { // generated quantities - matrix[n, t] infections; //latent infections - matrix[n, t - seeding_time] reports; // observed cases - array[n, t - seeding_time] int imputed_reports; - array[n, t - seeding_time] real r; + array[n, t] real sim_infections; //latent infections + array[n, t - seeding_time] real sim_reports; // observed cases + array[n, t - seeding_time] int sim_imputed_reports; + array[n, t - seeding_time] real sim_r; for (i in 1:n) { // generate infections from Rt trace - vector[delay_type_max[gt_id] + 1] gt_rev_pmf; - gt_rev_pmf = get_delay_rev_pmf( - gt_id, delay_type_max[gt_id] + 1, delay_types_p, delay_types_id, - delay_types_groups, delay_max, delay_np_pmf, - delay_np_pmf_groups, delay_mean[i], delay_sd[i], delay_dist, - 1, 1, 0 - ); - - infections[i] = to_row_vector(generate_infections( - to_vector(R[i]), seeding_time, gt_rev_pmf, initial_infections[i], - initial_growth[i], pop, future_time - )); - - if (delay_id) { - vector[delay_type_max[delay_id] + 1] delay_rev_pmf = get_delay_rev_pmf( - delay_id, delay_type_max[delay_id] + 1, delay_types_p, delay_types_id, - delay_types_groups, delay_max, delay_np_pmf, - delay_np_pmf_groups, delay_mean[i], delay_sd[i], delay_dist, - 0, 1, 0 - ); - // convolve 
from latent infections to mean of observations - reports[i] = to_row_vector(convolve_to_report( - to_vector(infections[i]), delay_rev_pmf, seeding_time) - ); - } else { - reports[i] = to_row_vector(infections[(seeding_time + 1):t]); +#include chunks/sim_vars.stan + vector[t - seeding_time] R = to_vector(R_samples[i]); + array[seeding_time > 0] real initial_infections; + array[seeding_time > 1] real initial_growth; + if (seeding_time > 0) { + initial_infections = initial_infections_samples[i]; + if (seeding_time > 1 ) { + initial_growth = initial_growth_samples[i]; + } } - // weekly reporting effect - if (week_effect > 1) { - reports[i] = to_row_vector( - day_of_week_effect(to_vector(reports[i]), day_of_week, - to_vector(day_of_week_simplex[i]))); + { + vector[t] infections; + vector[t - seeding_time] reports; + array[t - seeding_time] int imputed_reports; + array[t - seeding_time] real r; +#include chunks/gt_rev_pmf.stan +#include chunks/delay_rev_pmf.stan +#include chunks/generate_infections.stan +#include chunks/convolve_to_report.stan +#include chunks/day_of_week_effect.stan +#include chunks/scale_obs.stan +#include chunks/impute_reports.stan +#include chunks/R_to_growth.stan + sim_infections[i] = to_array_1d(infections); + sim_reports[i] = to_array_1d(reports); + sim_imputed_reports[i] = to_array_1d(imputed_reports); + sim_r[i] = r; } - // scale observations - if (obs_scale) { - reports[i] = to_row_vector(scale_obs(to_vector(reports[i]), frac_obs[i, 1])); - } - // simulate reported cases - imputed_reports[i] = report_rng( - to_vector(reports[i]), rep_phi[i], model_type - ); - { - real gt_mean = rev_pmf_mean(gt_rev_pmf, 1); - real gt_var = rev_pmf_var(gt_rev_pmf, 1, gt_mean); - r[i] = R_to_growth(to_vector(R[i]), gt_mean, gt_var); - } } } diff --git a/inst/stan/simulate_secondary.stan b/inst/stan/simulate_secondary.stan index 36e6a0edd..df671469d 100644 --- a/inst/stan/simulate_secondary.stan +++ b/inst/stan/simulate_secondary.stan @@ -14,7 +14,7 @@ data { int 
all_dates; // should all dates have simulations returned // secondary model specific data array[t - h] int obs; // observed secondary data - matrix[n, t] primary; // observed primary data + matrix[n, t] primary_samples; // observed primary data #include data/secondary.stan // delay from infection to report #include data/simulation_delays.stan @@ -23,37 +23,26 @@ data { } transformed data { - array[delay_types] int delay_type_max = get_delay_type_max( - delay_types, delay_types_p, delay_types_id, - delay_types_groups, delay_max, delay_np_pmf_groups - ); + int predict = t - h; +#include chunks/delay_type_max.stan } generated quantities { array[n, all_dates ? t : h] int sim_secondary; for (i in 1:n) { - vector[t] secondary; - vector[delay_type_max[delay_id] + 1] delay_rev_pmf = get_delay_rev_pmf( - delay_id, delay_type_max[delay_id] + 1, delay_types_p, delay_types_id, - delay_types_groups, delay_max, delay_np_pmf, - delay_np_pmf_groups, delay_mean[i], delay_sd[i], delay_dist, - 0, 1, 0 - ); + vector[t] reports; + vector[t] primary = to_vector(primary_samples[i]); +#include chunks/sim_vars.stan +#include chunks/delay_rev_pmf.stan // calculate secondary reports from primary - secondary = - calculate_secondary( - to_vector(primary[i]), obs, frac_obs[i], delay_rev_pmf, cumulative, - historic, primary_hist_additive, current, primary_current_additive, - t - h + 1 - ); +#include chunks/calculate_secondary.stan // weekly reporting effect - if (week_effect > 1) { - secondary = day_of_week_effect(secondary, day_of_week, to_vector(day_of_week_simplex[i])); - } +#include chunks/day_of_week_effect.stan + // simulate secondary reports sim_secondary[i] = report_rng( - tail(secondary, all_dates ? t : h), rep_phi[i], model_type + tail(reports, all_dates ? 
t : h), rep_phi, model_type ); } } diff --git a/man/extract_parameter_samples.Rd b/man/extract_parameter_samples.Rd index f6684d0e2..fb979ed8c 100644 --- a/man/extract_parameter_samples.Rd +++ b/man/extract_parameter_samples.Rd @@ -5,7 +5,7 @@ \title{Extract Parameter Samples from a Stan Model} \usage{ extract_parameter_samples( - stan_fit, + samples, data, reported_dates, reported_inf_dates, @@ -14,7 +14,8 @@ extract_parameter_samples( ) } \arguments{ -\item{stan_fit}{A fit Stan model as returned by \code{rstan:sampling}.} +\item{samples}{A list of arrays containing samples extracted from a stan +model using \code{\link[rstan]{extract}}} \item{data}{A list of the data supplied to the \code{rstan::sampling} call.} diff --git a/tests/testthat/test-delays.R b/tests/testthat/test-delays.R index e7daddf7e..2a62ee87c 100644 --- a/tests/testthat/test-delays.R +++ b/tests/testthat/test-delays.R @@ -18,48 +18,48 @@ delay_params <- test_that("generation times can be specified in different ways", { expect_equal( test_stan_delays(params = delay_params), - c(0, 1) + c(0, 1, 1, 1) ) expect_equal( test_stan_delays( generation_time = generation_time_opts(dist_spec(mean = 3)), params = delay_params ), - c(0, 0, 0, 1) + c(0, 0, 0, 1, 1, 1) ) expect_equal( round(test_stan_delays( generation_time = generation_time_opts(dist_spec(mean = 3, sd = 1, max = 4)), params = delay_params ), digits = 2), - c(0.02, 0.11, 0.22, 0.30, 0.35) + c(0.02, 0.11, 0.22, 0.30, 0.35, 1.00, 1.00) ) }) test_that("delay parameters can be specified in different ways", { expect_equal( - tail(test_stan_delays( + test_stan_delays( delays = delay_opts(dist_spec(mean = 3)), params = delay_params - ), n = -2), - c(0, 0, 0, 1) + ), + c(0, 1, 0, 0, 0, 1, 1) ) expect_equal( - tail(round(test_stan_delays( + round(test_stan_delays( delays = delay_opts(dist_spec(mean = 3, sd = 1, max = 4)), params = delay_params - ), digits = 2), n = -2), - c(0.02, 0.11, 0.22, 0.30, 0.35) + ), digits = 2), + c(0.00, 1.00, 0.02, 0.11, 0.22,
0.30, 0.35, 1.00) ) }) test_that("truncation parameters can be specified in different ways", { expect_equal( - tail(round(test_stan_delays( + round(test_stan_delays( truncation = trunc_opts(dist = dist_spec(mean = 3, sd = 1, max = 4)), params = delay_params - ), digits = 2), n = -2), - c(0.02, 0.11, 0.22, 0.30, 0.35) + ), digits = 2), + c(0.00, 1.00, 1.00, 0.02, 0.11, 0.22, 0.30, 0.35) ) }) diff --git a/tests/testthat/test-dist_spec.R b/tests/testthat/test-dist_spec.R index 45aa94672..791a25240 100644 --- a/tests/testthat/test-dist_spec.R +++ b/tests/testthat/test-dist_spec.R @@ -187,7 +187,7 @@ test_that("print.dist_spec correctly prints the parameters of the uncertain logn test_that("print.dist_spec correctly prints the parameters of an empty distribution", { empty <- dist_spec() - expect_output(print(empty), "Empty `dist_spec` distribution.") + expect_output(print(empty), "PMF \\[1\\].") }) test_that("print.dist_spec correctly prints the parameters of a combination of distributions", {