Skip to content

Commit

Permalink
reduce unnecessary function calls
Browse files Browse the repository at this point in the history
  • Loading branch information
sbfnk committed Feb 13, 2023
1 parent 2efbd48 commit 9037e21
Showing 1 changed file with 65 additions and 47 deletions.
112 changes: 65 additions & 47 deletions inst/stan/estimate_infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,23 @@ transformed data{
vector[trunc_max_fixed] trunc_fixed_pmf;
vector[delay_max_fixed] delay_fixed_pmf;

gt_fixed_pmf = convolve_ragged_pmf(
gt_np_pmf, gt_np_pmf_groups, gt_max_fixed
);
if (gt_n) {
gt_fixed_pmf = convolve_ragged_pmf(
gt_np_pmf, gt_np_pmf_groups, gt_max_fixed
);
}

trunc_fixed_pmf = convolve_ragged_pmf(
trunc_np_pmf, trunc_np_pmf_groups, trunc_max_fixed
);
if (trunc_n) {
trunc_fixed_pmf = convolve_ragged_pmf(
trunc_np_pmf, trunc_np_pmf_groups, trunc_max_fixed
);
}

delay_fixed_pmf = convolve_ragged_pmf(
delay_np_pmf, delay_np_pmf_groups, delay_max_fixed
);
if (delay_n) {
delay_fixed_pmf = convolve_ragged_pmf(
delay_np_pmf, delay_np_pmf_groups, delay_max_fixed
);
}
}

parameters{
Expand Down Expand Up @@ -95,9 +101,11 @@ transformed parameters {
// Estimate latent infections
if (estimate_r) {
// via Rt
gt_rev_pmf = combine_pmfs(
gt_fixed_pmf, gt_mean, gt_sd, gt_max, gt_dist, gt_max_total, 1, 1
);
if (gt_n) {
gt_rev_pmf = combine_pmfs(
gt_fixed_pmf, gt_mean, gt_sd, gt_max, gt_dist, gt_max_total, 1, 1
);
}
R = update_Rt(
ot_h, log_R[estimate_r], noise, breakpoints, bp_effects, stationary
);
Expand All @@ -111,30 +119,34 @@ transformed parameters {
shifted_cases, noise, fixed, backcalc_prior
);
}
// convolve from latent infections to mean of observations
{
if (delay_n) {
// convolve from latent infections to mean of observations
vector[delay_max_total] delay_rev_pmf;
delay_rev_pmf = combine_pmfs(
delay_fixed_pmf, delay_mean, delay_sd, delay_max, delay_dist, delay_max_total, 0, 1
);
reports = convolve_to_report(infections, delay_rev_pmf, seeding_time);
} else {
reports = infections[(seeding_time + 1):t];
}
if (week_effect > 1) {
// weekly reporting effect
reports = day_of_week_effect(reports, day_of_week, day_of_week_simplex);
}
if (obs_scale) {
// scaling of reported cases by fraction observed
reports = scale_obs(reports, frac_obs[1]);
}
if (trunc_n) {
// truncate near time cases to observed reports
vector[trunc_max_total] trunc_rev_cmf;
trunc_rev_cmf = reverse_mf(cumulative_sum(combine_pmfs(
trunc_fixed_pmf, trunc_mean, trunc_sd, trunc_max, trunc_dist, trunc_max_total, 0, 0
)));
obs_reports = truncate(reports[1:ot], trunc_rev_cmf, 0);
} else {
obs_reports = reports[1:ot];
}
// weekly reporting effect
if (week_effect > 1) {
reports = day_of_week_effect(reports, day_of_week, day_of_week_simplex);
}
// scaling of reported cases by fraction observed
if (obs_scale) {
reports = scale_obs(reports, frac_obs[1]);
}
// truncate near time cases to observed reports
{
vector[trunc_max_total] trunc_rev_cmf;
trunc_rev_cmf = reverse_mf(cumulative_sum(combine_pmfs(
trunc_fixed_pmf, trunc_mean, trunc_sd, trunc_max, trunc_dist, trunc_max_total, 0, 0
)));
obs_reports = truncate(reports[1:ot], trunc_rev_cmf, 0);
}
}

model {
Expand All @@ -144,30 +156,36 @@ model {
rho[1], alpha[1], eta, ls_meanlog, ls_sdlog, ls_min, ls_max, alpha_sd
);
}
// penalised priors for delay distributions
delays_lp(
delay_mean, delay_mean_mean,
delay_mean_sd, delay_sd, delay_sd_mean, delay_sd_sd,
delay_dist, delay_weight
);
// priors for truncation
delays_lp(
trunc_mean, trunc_sd,
trunc_mean_mean, trunc_mean_sd,
trunc_sd_mean, trunc_sd_sd,
trunc_dist, 1
);
if (delay_n_p) {
// penalised priors for delay distributions
delays_lp(
delay_mean, delay_mean_mean,
delay_mean_sd, delay_sd, delay_sd_mean, delay_sd_sd,
delay_dist, delay_weight
);
}
// priors for truncation
if (trunc_n_p) {
delays_lp(
trunc_mean, trunc_sd,
trunc_mean_mean, trunc_mean_sd,
trunc_sd_mean, trunc_sd_sd,
trunc_dist, 1
);
}
if (estimate_r) {
// priors on Rt
rt_lp(
log_R, initial_infections, initial_growth, bp_effects, bp_sd, bp_n,
seeding_time, r_logmean, r_logsd, prior_infections, prior_growth
);
// penalised_prior on generation interval
delays_lp(
gt_mean, gt_mean_mean, gt_mean_sd, gt_sd, gt_sd_mean, gt_sd_sd, gt_dist,
gt_weight
);
if (gt_n_p) {
delays_lp(
gt_mean, gt_mean_mean, gt_mean_sd, gt_sd, gt_sd_mean, gt_sd_sd, gt_dist,
gt_weight
);
}
}
// prior observation scaling
if (obs_scale) {
Expand Down

0 comments on commit 9037e21

Please sign in to comment.