Skip to content

Commit

Permalink
implement different model types
Browse files Browse the repository at this point in the history
  • Loading branch information
sbfnk committed Oct 13, 2022
1 parent e550465 commit 2406f38
Show file tree
Hide file tree
Showing 16 changed files with 124 additions and 106 deletions.
13 changes: 7 additions & 6 deletions R/create.R
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,7 @@ create_gp_data <- function(gp = gp_opts(), data) {
#' create_obs_model(obs_opts(week_length = 3), dates = dates)
create_obs_model <- function(obs = obs_opts(), dates) {
data <- list(
model_type = ifelse(obs$family %in% "poisson", 0, 1),
obs_dist = ifelse(obs$family %in% "poisson", 0, 1),
phi_mean = obs$phi[1],
phi_sd = obs$phi[2],
week_effect = ifelse(obs$week_effect, obs$week_length, 1),
Expand Down Expand Up @@ -386,7 +386,7 @@ create_obs_model <- function(obs = obs_opts(), dates) {
create_stan_data <- function(reported_cases, generation_time,
rt, gp, obs, delays, horizon,
backcalc, shifted_cases,
truncation) {
truncation, process_model) {

## make sure we have at least max_gt seeding time
delays$seeding_time <- max(delays$seeding_time, generation_time$max)
Expand Down Expand Up @@ -429,7 +429,8 @@ create_stan_data <- function(reported_cases, generation_time,
gt_sd_mean = generation_time$sd,
gt_sd_sd = generation_time$sd_sd,
max_gt = generation_time$max,
burn_in = 0
burn_in = 0,
process_model = process_model
)
# add delay data
data <- c(data, delays)
Expand Down Expand Up @@ -530,7 +531,7 @@ create_initial_conditions <- function(data) {
)
out$alpha <- array(truncnorm::rtruncnorm(1, a = 0, mean = 0, sd = data$alpha_sd))
}
if (data$model_type == 1) {
if (data$obs_dist == 1) {
out$rep_phi <- array(
truncnorm::rtruncnorm(
1, a = 0, mean = data$phi_mean, sd = data$phi_sd / 10
Expand All @@ -542,10 +543,10 @@ create_initial_conditions <- function(data) {
if (data$seeding_time > 1) {
out$initial_growth <- array(rnorm(1, data$prior_growth, 0.01))
}
out$log_R <- array(rnorm(
out$base_cov <- rnorm(
n = 1, mean = convert_to_logmean(data$r_mean, data$r_sd),
sd = convert_to_logsd(data$r_mean, data$r_sd) * 0.1
))
)
if (data$gt_mean_sd > 0) {
out$gt_mean <- array(truncnorm::rtruncnorm(1,
a = 0, mean = data$gt_mean_mean,
Expand Down
8 changes: 7 additions & 1 deletion R/estimate_infections.R
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,7 @@
#' }
estimate_infections <- function(reported_cases,
generation_time,
model = "R",
delays = delay_opts(),
truncation = trunc_opts(),
rt = rt_opts(),
Expand Down Expand Up @@ -247,6 +248,10 @@ estimate_infections <- function(reported_cases,
)
reported_cases <- reported_cases[-(1:backcalc$prior_window)]

model_choices <- c("infections", "growth", "R")
model <- match.arg(model_choices, choices = model_choices)
process_model <- which(model == model_choices) - 1

# Define stan model parameters
data <- create_stan_data(
reported_cases = reported_cases,
Expand All @@ -258,7 +263,8 @@ estimate_infections <- function(reported_cases,
obs = obs,
backcalc = backcalc,
shifted_cases = shifted_cases$confirm,
horizon = horizon
horizon = horizon,
process_model = process_model
)

# Set up default settings
Expand Down
2 changes: 1 addition & 1 deletion R/extract.R
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ extract_parameter_samples <- function(stan_fit, data, reported_dates, reported_i
out$gt_sd <- extract_static_parameter("gt_sd", samples)
out$gt_sd <- out$gt_sd[, value := value.V1][, value.V1 := NULL]
}
if (data$model_type == 1) {
if (data$obs_dist == 1) {
out$reporting_overdispersion <- extract_static_parameter("rep_phi", samples)
out$reporting_overdispersion <- out$reporting_overdispersion[, value := value.V1][
,
Expand Down
3 changes: 3 additions & 0 deletions inst/stan/data/covariates.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
int process_model; // 0 = infections; 1 = growth; 2 = rt
int bp_n; // no of breakpoints (0 = no breakpoints)
int breakpoints[t - seeding_time]; // when do breakpoints occur
2 changes: 1 addition & 1 deletion inst/stan/data/observation_model.stan
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
int day_of_week[t - seeding_time]; // day of the week indicator (1 - 7)
int model_type; // type of model: 0 = poisson otherwise negative binomial
int obs_dist; // type of model: 0 = poisson otherwise negative binomial
real phi_mean; // Mean and sd of the normal prior for the
real phi_sd; // reporting process
int week_effect; // length of week effect
Expand Down
2 changes: 1 addition & 1 deletion inst/stan/data/observations.stan
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
int t; // unobserved time
int seeding_time; // time period used for seeding and not observed
int<lower = 1> seeding_time; // time period used for seeding and not observed
int horizon; // forecast horizon
int future_time; // time in future for Rt
int<lower = 0> cases[t - horizon - seeding_time]; // observed cases
Expand Down
2 changes: 0 additions & 2 deletions inst/stan/data/rt.stan
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
real prior_growth; // prior on initial growth rate
real <lower = 0> r_mean; // prior mean of reproduction number
real <lower = 0> r_sd; // prior standard deviation of reproduction number
int bp_n; // no of breakpoints (0 = no breakpoints)
int breakpoints[t - seeding_time]; // when do breakpoints occur
int future_fixed; // is underlying future Rt assumed to be fixed
int fixed_from; // Reference date for when Rt estimation should be fixed
int pop; // Initial susceptible population
4 changes: 2 additions & 2 deletions inst/stan/data/simulation_observation_model.stan
Original file line number Diff line number Diff line change
Expand Up @@ -3,5 +3,5 @@
real<lower = 0> day_of_week_simplex[n, week_effect];
int obs_scale;
real frac_obs[n, obs_scale];
int model_type;
real<lower = 0> rep_phi[n, model_type]; // overdispersion of the reporting process
int obs_dist;
real<lower = 0> rep_phi[n, obs_dist]; // overdispersion of the reporting process
60 changes: 35 additions & 25 deletions inst/stan/estimate_infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ functions {
#include functions/pmfs.stan
#include functions/convolve.stan
#include functions/gaussian_process.stan
#include functions/rt.stan
#include functions/covariates.stan
#include functions/infections.stan
#include functions/observation_model.stan
#include functions/generated_quantities.stan
Expand All @@ -12,6 +12,7 @@ functions {
data {
#include data/observations.stan
#include data/delays.stan
#include data/covariates.stan
#include data/gaussian_process.stan
#include data/generation_time.stan
#include data/rt.stan
Expand All @@ -36,8 +37,7 @@ parameters{
real<lower = ls_min,upper=ls_max> rho[fixed ? 0 : 1]; // length scale of noise GP
real<lower = 0> alpha[fixed ? 0 : 1]; // scale of of noise GP
vector[fixed ? 0 : M] eta; // unconstrained noise
// Rt
vector[estimate_r] log_R; // baseline reproduction number estimate (log)
real base_cov; // covariate (R/r/inf)
real initial_infections[estimate_r] ; // seed infections
real initial_growth[estimate_r && seeding_time > 1 ? 1 : 0]; // seed growth rate
real<lower = 0, upper = max_gt> gt_mean[estimate_r && gt_mean_sd > 0]; // mean of generation time (if uncertain)
Expand All @@ -51,31 +51,36 @@ parameters{
real<lower = 0> frac_obs[obs_scale]; // fraction of cases that are ultimately observed
real truncation_mean[truncation]; // mean of truncation
real<lower = 0> truncation_sd[truncation]; // sd of truncation
real<lower = 0> rep_phi[model_type]; // overdispersion of the reporting process
real<lower = 0> rep_phi[obs_dist]; // overdispersion of the reporting process
}

transformed parameters {
vector[fixed ? 0 : noise_terms] noise; // noise generated by the gaussian process
vector<lower = 0, upper = 10 * r_mean>[estimate_r > 0 ? ot_h : 0] R; // reproduction number
vector[seeding_time] uobs_inf;
vector[t] infections; // latent infections
vector[ot_h] cov; // covaraites
vector[ot_h] reports; // estimated reported cases
vector[ot] obs_reports; // observed estimated reported cases
// GP in noise - spectral densities
if (!fixed) {
noise = update_gp(PHI, M, L, alpha[1], rho[1], eta, gp_type);
}
// update covariates
cov = update_covariate(base_cov, noise, breakpoints, bp_effects, stationary, ot_h, 0);
uobs_inf = generate_seed(initial_infections, initial_growth, seeding_time);
// Estimate latent infections
if (estimate_r) {
if (process_model == 0) {
// via deconvolution
infections = infection_model(cov, uobs_inf, future_time);
} else if (process_model == 1) {
// via growth
infections = growth_model(cov, uobs_inf, future_time);
} else if (process_model == 2) {
// via Rt
real set_gt_mean = (gt_mean_sd > 0 ? gt_mean[1] : gt_mean_mean);
real set_gt_sd = (gt_sd_sd > 0 ? gt_sd[1] : gt_sd_mean);
R = update_Rt(ot_h, log_R[estimate_r], noise, breakpoints, bp_effects, stationary);
infections = generate_infections(R, seeding_time, set_gt_mean, set_gt_sd, max_gt,
initial_infections, initial_growth,
pop, future_time);
} else {
// via deconvolution
infections = deconvolve_infections(shifted_cases, noise, fixed, backcalc_prior);
infections = renewal_model(cov, uobs_inf, set_gt_mean, set_gt_sd, max_gt,
pop, future_time);
}
// convolve from latent infections to mean of observations
reports = convolve_to_report(infections, delay_mean, delay_sd, max_delay, seeding_time);
Expand All @@ -102,49 +107,54 @@ model {
gaussian_process_lp(rho[1], alpha[1], eta, ls_meanlog,
ls_sdlog, ls_min, ls_max, alpha_sd);
}
// penalised priors for delay distributions
delays_lp(delay_mean, delay_mean_mean, delay_mean_sd, delay_sd, delay_sd_mean, delay_sd_sd, t);
if (delays > 0) {
// penalised priors for delay distributions
delays_lp(delay_mean, delay_mean_mean, delay_mean_sd, delay_sd, delay_sd_mean, delay_sd_sd, t);
}
// priors for truncation
truncation_lp(truncation_mean, truncation_sd, trunc_mean_mean, trunc_mean_sd,
trunc_sd_mean, trunc_sd_sd);
if (truncation) {
truncation_lp(truncation_mean, truncation_sd, trunc_mean_mean, trunc_mean_sd,
trunc_sd_mean, trunc_sd_sd);
}
if (estimate_r) {
// priors on Rt
rt_lp(log_R, initial_infections, initial_growth, bp_effects, bp_sd, bp_n, seeding_time,
r_logmean, r_logsd, prior_infections, prior_growth);
// penalised_prior on generation interval
generation_time_lp(gt_mean, gt_mean_mean, gt_mean_sd, gt_sd, gt_sd_mean, gt_sd_sd, ot);
}
// priors on Rt
covariate_lp(base_cov, bp_effects, bp_sd, bp_n, r_logmean, r_logsd);
infections_lp(initial_infections, initial_growth, prior_infections, prior_growth,
seeding_time);
// prior observation scaling
if (obs_scale) {
frac_obs[1] ~ normal(obs_scale_mean, obs_scale_sd) T[0,];
}
// observed reports from mean of reports (update likelihood)
if (likelihood) {
report_lp(cases, obs_reports, rep_phi, phi_mean, phi_sd,
model_type, obs_weight);
obs_dist, obs_weight);
}
}

generated quantities {
int imputed_reports[ot_h];
vector[estimate_r > 0 ? 0: ot_h] gen_R;
vector[estimate_r > 0 ? 0: ot_h] R;
real r[ot_h - 1];
vector[return_likelihood > 1 ? ot : 0] log_lik;
if (estimate_r == 0){
// sample generation time
real gt_mean_sample = (gt_mean_sd > 0 ? normal_rng(gt_mean_mean, gt_mean_sd) : gt_mean_mean);
real gt_sd_sample = (gt_sd_sd > 0 ? normal_rng(gt_sd_mean, gt_sd_sd) : gt_sd_mean);
// calculate Rt using infections and generation time
gen_R = calculate_Rt(infections, seeding_time, gt_mean_sample, gt_sd_sample,
R = calculate_Rt(infections, seeding_time, gt_mean_sample, gt_sd_sample,
max_gt, rt_half_window);
}
// estimate growth from infections
r = calculate_growth(infections, seeding_time + 1);
// simulate reported cases
imputed_reports = report_rng(reports, rep_phi, model_type);
imputed_reports = report_rng(reports, rep_phi, obs_dist);
// log likelihood of model
if (return_likelihood) {
log_lik = report_log_lik(cases, obs_reports, rep_phi, model_type,
log_lik = report_log_lik(cases, obs_reports, rep_phi, obs_dist,
obs_weight);
}
}
8 changes: 4 additions & 4 deletions inst/stan/estimate_secondary.stan
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ parameters{
real<lower = 0> frac_obs[obs_scale]; // fraction of cases that are ultimately observed
real truncation_mean[truncation]; // mean of truncation
real truncation_sd[truncation]; // sd of truncation
real<lower = 0> rep_phi[model_type]; // overdispersion of the reporting process
real<lower = 0> rep_phi[obs_dist]; // overdispersion of the reporting process
}

transformed parameters {
Expand Down Expand Up @@ -56,18 +56,18 @@ model {
// observed secondary reports from mean of secondary reports (update likelihood)
if (likelihood) {
report_lp(obs[(burn_in + 1):t], secondary[(burn_in + 1):t],
rep_phi, phi_mean, phi_sd, model_type, 1);
rep_phi, phi_mean, phi_sd, obs_dist, 1);
}
}

generated quantities {
int sim_secondary[t - burn_in];
vector[return_likelihood > 1 ? t - burn_in : 0] log_lik;
// simulate secondary reports
sim_secondary = report_rng(secondary[(burn_in + 1):t], rep_phi, model_type);
sim_secondary = report_rng(secondary[(burn_in + 1):t], rep_phi, obs_dist);
// log likelihood of model
if (return_likelihood) {
log_lik = report_log_lik(obs[(burn_in + 1):t], secondary[(burn_in + 1):t],
rep_phi, model_type, obs_weight);
rep_phi, obs_dist, obs_weight);
}
}
16 changes: 5 additions & 11 deletions inst/stan/functions/covariates.stan
Original file line number Diff line number Diff line change
Expand Up @@ -40,21 +40,15 @@ vector update_covariate(real base_cov, vector noise, int[] bps,
}
return(cov);
}
// Rt priors
void covariate_lp(vector base_cov, real[] initial_infections, real[] initial_growth,
real[] bp_effects, real[] bp_sd, int bp_n, int seeding_time,
real r_logmean, real r_logsd, real prior_infections,
real prior_growth) {

void covariate_lp(real base_cov,
real[] bp_effects, real[] bp_sd, int bp_n,
real r_logmean, real r_logsd) {
// initial prior
base_cov ~ normal(base_logmean, base_logsd);
base_cov ~ normal(r_logmean, r_logsd);
//breakpoint effects on Rt
if (bp_n > 0) {
bp_sd[1] ~ normal(0, 0.1) T[0,];
bp_effects ~ normal(0, bp_sd[1]);
}
// initial infections
initial_infections ~ normal(prior_infections, 0.2);
if (seeding_time > 1) {
initial_growth ~ normal(prior_growth, 0.2);
}
}
Loading

0 comments on commit 2406f38

Please sign in to comment.