Skip to content

Commit

Permalink
check fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
sbfnk committed Nov 18, 2022
1 parent d17f1b1 commit 40709d2
Show file tree
Hide file tree
Showing 8 changed files with 51 additions and 62 deletions.
5 changes: 3 additions & 2 deletions R/opts.R
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,7 @@ generation_time_opts <- function(..., disease, source, max = 15L, fixed = FALSE,
names(gt) <- paste0("gt_", names(gt))

return(gt)
}

#' Delay Distribution Options
#'
Expand Down Expand Up @@ -300,7 +301,7 @@ process_opts <- function(model = "R",
rw = 0,
use_breakpoints = TRUE,
future = "latest",
stationary = FALSE
stationary = FALSE,
pop = 0) {

## check
Expand All @@ -319,7 +320,7 @@ process_opts <- function(model = "R",
use_breakpoints = use_breakpoints,
future = future,
stationary = stationary,
pop = pop,
pop = pop
)

# replace default settings with those specified by user
Expand Down
3 changes: 2 additions & 1 deletion inst/stan/data/covariates.stan
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
int process_model; // 0 = infections; 1 = growth; 2 = rt
int bp_n; // no of breakpoints (0 = no breakpoints)
int breakpoints[t - seeding_time]; // when do breakpoints occur
int cov_mean_const; // 0 = not const mean; 1 = const mean
real<lower = 0> cov_mean_mean[cov_mean_const]; // const covariate mean
real<lower = 0> cov_mean_sd[cov_mean_const]; // const covariate sd
real<lower = 0> cov_t[cov_mean_const ? 0 : t] // time-varying covariate mean
vector<lower = 0>[cov_mean_const ? 0 : t] cov_t; // time-varying covariate mean
2 changes: 1 addition & 1 deletion inst/stan/data/simulation_rt.stan
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@
real<lower = 0> gt_sd[n, 1]; // sd of generation time
int<lower = 1> gt_max[1]; // maximum generation time
int gt_dist[1]; // 0 = lognormal; 1 = gamma
matrix[n, t - seeding_time] R; // reproduction number
vector[n] R[t - seeding_time]; // reproduction number
int pop; // susceptible population
42 changes: 25 additions & 17 deletions inst/stan/estimate_infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ transformed data{
int noise_terms = setup_noise(ot_h, t, horizon, estimate_r, stationary, future_fixed, fixed_from);
matrix[noise_terms, M] PHI = setup_gp(M, L, noise_terms); // basis function
// covariate mean
real cov_mean_logmean[cov_mean_const] = log(cov_mean^2 / sqrt(cov_sd^2 + cov_mean^2));
real cov_mean_logsd[cov_mean_const] = sqrt(log(1 + (cov_sd^2 / cov_mean^2)));
real cov_mean_logmean[cov_mean_const];
real cov_mean_logsd[cov_mean_const];

int delay_max_fixed = (n_fixed_delays == 0 ? 0 :
sum(delay_max[fixed_delays]) - num_elements(fixed_delays) + 1);
Expand All @@ -39,6 +39,11 @@ transformed data{
vector[truncation && trunc_fixed[1] ? trunc_max[1] : 0] trunc_fixed_pmf;
vector[delay_max_fixed] delays_fixed_pmf;

if (cov_mean_const) {
cov_mean_logmean[1] = log(cov_mean_mean[1]^2 / sqrt(cov_mean_sd[1]^2 + cov_mean_mean[1]^2));
cov_mean_logsd[1] = sqrt(log(1 + (cov_mean_sd[1]^2 / cov_mean_mean[1]^2)));
}

if (gt_fixed[1]) {
gt_fixed_pmf = discretised_pmf(gt_mean_mean[1], gt_sd_mean[1], gt_max[1], gt_dist[1], 1);
}
Expand Down Expand Up @@ -94,8 +99,10 @@ transformed parameters {
noise = update_gp(PHI, M, L, alpha[1], rho[1], eta, gp_type);
}
// update covariates
cov = update_covariate(log_cov_mean, cov_t, noise, breakpoints, bp_effects,
stationary, ot_h, 0);
cov = update_covariate(
log_cov_mean, cov_t, noise, breakpoints, bp_effects,
stationary, ot_h
);
uobs_inf = generate_seed(initial_infections, initial_growth, seeding_time);
// Estimate latent infections
if (process_model == 0) {
Expand Down Expand Up @@ -148,28 +155,29 @@ model {
}
if (delays > 0) {
// penalised priors for delay distributions
delays_lp(delay_mean, delay_mean_mean, delay_mean_sd, delay_sd, delay_sd_mean, delay_sd_sd, t);
delays_lp(
delay_mean, delay_mean_mean, delay_mean_sd, delay_sd, delay_sd_mean, delay_sd_sd, t
);
}
// priors for truncation
if (truncation) {
truncation_lp(truncation_mean, truncation_sd, trunc_mean_mean, trunc_mean_sd,
trunc_sd_mean, trunc_sd_sd);
truncation_lp(
trunc_mean, trunc_sd, trunc_mean_mean, trunc_mean_sd, trunc_sd_mean, trunc_sd_sd
);
}
if (estimate_r) {
// priors on Rt
rt_lp(
log_R, initial_infections, initial_growth, bp_effects, bp_sd, bp_n,
seeding_time, r_logmean, r_logsd, prior_infections, prior_growth
);
// penalised_prior on generation interval
generation_time_lp(
gt_mean, gt_mean_mean, gt_mean_sd, gt_sd, gt_sd_mean, gt_sd_sd, gt_weight
);
}
// priors on Rt
covariate_lp(log_cov_mean, bp_effects, bp_sd, bp_n, cov_mean_logmean, cov_mean_logsd);
infections_lp(initial_infections, initial_growth, prior_infections, prior_growth,
seeding_time);
// priors on covariates and infections
covariate_lp(
log_cov_mean, bp_effects, bp_sd, bp_n, cov_mean_logmean, cov_mean_logsd
);
infections_lp(
initial_infections, initial_growth, prior_infections, prior_growth, seeding_time
);
// prior observation scaling
if (obs_scale) {
frac_obs[1] ~ normal(obs_scale_mean, obs_scale_sd) T[0,];
Expand All @@ -185,7 +193,7 @@ model {
generated quantities {
int imputed_reports[ot_h];
vector[estimate_r > 0 ? ot_h : 0] R;
real r[ot_h];
vector[ot_h] r;
vector[return_likelihood > 1 ? ot : 0] log_lik;
if (estimate_r == 0 && process_model != 2) {
// sample generation time
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
// update combined covariates
vector update_covariate(vector log_cov_mean, vector cov_t,
vector update_covariate(real[] log_cov_mean, vector cov_t,
vector noise, int[] bps,
real[] bp_effects, int stationary,
int t) {
Expand All @@ -10,7 +10,7 @@ vector update_covariate(vector log_cov_mean, vector cov_t,
// define result vectors
vector[t] bp = rep_vector(0, t);
vector[t] gp = rep_vector(0, t);
vector[t] R;
vector[t] cov;
// initialise breakpoints
if (bp_n) {
for (s in 1:t) {
Expand All @@ -35,7 +35,7 @@ vector update_covariate(vector log_cov_mean, vector cov_t,
}
}
if (num_elements(log_cov_mean) > 0) {
cov = rep(log_cov_mean, t);
cov = rep_vector(log_cov_mean[1], t);
} else {
cov = log(cov_t);
}
Expand All @@ -47,10 +47,10 @@ vector update_covariate(vector log_cov_mean, vector cov_t,

void covariate_lp(real[] log_cov_mean,
real[] bp_effects, real[] bp_sd, int bp_n,
real cov_mean_logmean, real cov_mean_logsd) {
real[] cov_mean_logmean, real[] cov_mean_logsd) {
// initial prior
if (num_elements(log_cov_mean) > 0) {
log_cov_mean ~ normal(cov_mean_logmean, cov_mean_logsd);
log_cov_mean ~ normal(cov_mean_logmean[1], cov_mean_logsd[1]);
}
//breakpoint effects on Rt
if (bp_n > 0) {
Expand Down
22 changes: 3 additions & 19 deletions inst/stan/functions/generated_quantities.stan
Original file line number Diff line number Diff line change
Expand Up @@ -15,28 +15,12 @@ vector calculate_Rt(vector infections, int seeding_time,
}
return(R);
}
// Convert an estimate of Rt to growth
real[] R_to_growth(vector R, real gt_mean, real gt_sd) {
int t = num_elements(R);
real r[t];
if (gt_sd > 0) {
real k = pow(gt_sd / gt_mean, 2);
for (s in 1:t) {
r[s] = (pow(R[s], k) - 1) / (k * gt_mean);
}
} else {
// limit as gt_sd -> 0
for (s in 1:t) {
r[s] = log(R[s]) / gt_mean;
}
}
return(r);
}

// Calculate growth rate
real[] calculate_growth(vector infections, int seeding_time) {
vector calculate_growth(vector infections, int seeding_time) {
int t = num_elements(infections);
int ot = t - seeding_time;
vector[t] log_inf = log(infections);
vector[ot] growth = log_inf[(seeding_time + 1):t] - log_inf[seeding_time:(t - 1)];
return(to_array_1d(growth));
return(growth);
}
5 changes: 1 addition & 4 deletions inst/stan/functions/infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ vector generate_seed(real[] initial_infections, real[] initial_growth, int uot)
return(seed_infs);
}
// generate infections using infectiousness
vector renewal_model(vector oR, vector uobs_infs, vector gt_pmf,
vector renewal_model(vector oR, vector uobs_inf, vector gt_pmf,
int pop, int ht) {
// time indices and storage
int ot = num_elements(oR);
Expand All @@ -39,9 +39,6 @@ vector renewal_model(vector oR, vector uobs_infs, vector gt_pmf,
vector[t] infections = rep_vector(1e-5, t);
vector[ot] cum_infections = rep_vector(0, ot);
vector[ot] infectiousness = rep_vector(1e-5, ot);
// generation time pmf
vector[max_gt] gt_pmf = rep_vector(1e-5, max_gt);
int gt_indexes[max_gt];
// Initialise infections
infections[1:uot] = uobs_inf;
// calculate cumulative infections
Expand Down
24 changes: 11 additions & 13 deletions inst/stan/simulate_infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,10 @@ transformed data {

generated quantities {
// generated quantities
matrix[n, t] infections; //latent infections
matrix[n, t - seeding_time] reports; // observed cases
vector[n] infections[t]; //latent infections
vector[n] reports[t - seeding_time]; // observed cases
int imputed_reports[n, t - seeding_time];
real r[n, t - seeding_time];
vector[n] r[t - seeding_time];
vector[seeding_time] uobs_inf;
for (i in 1:n) {
// generate infections from Rt trace
Expand All @@ -50,23 +50,21 @@ generated quantities {

uobs_inf = generate_seed(initial_infections[i], initial_growth[i], seeding_time);
// generate infections from Rt trace
infections[i] = to_row_vector(renewal_model(to_vector(R[i]), uobs_inf,
gt_mean[i, 1], gt_sd[i, 1], max_gt,
pop, future_time));
infections[i] = renewal_model(R[i], uobs_inf, gt_pmf, pop, future_time);
// convolve from latent infections to mean of observations
reports[i] = to_row_vector(convolve_to_report(to_vector(infections[i]), delay_pmf, seeding_time));
reports[i] = convolve_to_report(infections[i], delay_pmf, seeding_time);
// weekly reporting effect
if (week_effect > 1) {
reports[i] = to_row_vector(
day_of_week_effect(to_vector(reports[i]), day_of_week,
to_vector(day_of_week_simplex[i])));
reports[i] = day_of_week_effect(
reports[i], day_of_week, to_vector(day_of_week_simplex[i])
);
}
// scale observations
if (obs_scale) {
reports[i] = to_row_vector(scale_obs(to_vector(reports[i]), frac_obs[i, 1]));
reports[i] = scale_obs(reports[i], frac_obs[i, 1]);
}
// simulate reported cases
imputed_reports[i] = report_rng(to_vector(reports[i]), rep_phi[i], obs_dist);
r[i] = calculate_growth(to_vector(infections[i]), seeding_time);
imputed_reports[i] = report_rng(reports[i], rep_phi[i], obs_dist);
r[i] = calculate_growth(infections[i], seeding_time);
}
}

0 comments on commit 40709d2

Please sign in to comment.