Skip to content

Commit

Permalink
add profiling
Browse files Browse the repository at this point in the history
  • Loading branch information
sbfnk committed Mar 15, 2024
1 parent 19b5707 commit 985f997
Showing 1 changed file with 134 additions and 93 deletions.
227 changes: 134 additions & 93 deletions inst/stan/estimate_infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,13 @@ transformed data{
real r_logmean = log(r_mean^2 / sqrt(r_sd^2 + r_mean^2));
real r_logsd = sqrt(log(1 + (r_sd^2 / r_mean^2)));

array[delay_types] int delay_type_max = get_delay_type_max(
delay_types, delay_types_p, delay_types_id,
delay_types_groups, delay_max, delay_np_pmf_groups
);
array[delay_types] int delay_type_max;
profile("assign max") {
delay_type_max = get_delay_type_max(
delay_types, delay_types_p, delay_types_id,
delay_types_groups, delay_max, delay_np_pmf_groups
);
}
}

parameters{
Expand Down Expand Up @@ -63,93 +66,129 @@ transformed parameters {
vector[ot] obs_reports; // observed estimated reported cases
vector[estimate_r * (delay_type_max[gt_id] + 1)] gt_rev_pmf;
// GP in noise - spectral densities
if (!fixed) {
noise = update_gp(PHI, M, L, alpha[1], rho[1], eta, gp_type);
profile("update gp") {
if (!fixed) {
noise = update_gp(PHI, M, L, alpha[1], rho[1], eta, gp_type);
}
}
// Estimate latent infections
if (estimate_r) {
gt_rev_pmf = get_delay_rev_pmf(
gt_id, delay_type_max[gt_id] + 1, delay_types_p, delay_types_id,
delay_types_groups, delay_max, delay_np_pmf,
delay_np_pmf_groups, delay_params, delay_params_groups, delay_dist,
1, 1, 0
);
R = update_Rt(
ot_h, log_R[estimate_r], noise, breakpoints, bp_effects, stationary
);
infections = generate_infections(
R, seeding_time, gt_rev_pmf, initial_infections, initial_growth, pop,
future_time
);
profile("gt") {
gt_rev_pmf = get_delay_rev_pmf(
gt_id, delay_type_max[gt_id] + 1, delay_types_p, delay_types_id,
delay_types_groups, delay_max, delay_np_pmf,
delay_np_pmf_groups, delay_params, delay_params_groups, delay_dist,
1, 1, 0
);
}
profile("R") {
R = update_Rt(
ot_h, log_R[estimate_r], noise, breakpoints, bp_effects, stationary
);
}
profile("infections") {
infections = generate_infections(
R, seeding_time, gt_rev_pmf, initial_infections, initial_growth, pop,
future_time
);
}
} else {
// via deconvolution
infections = deconvolve_infections(
shifted_cases, noise, fixed, backcalc_prior
);
profile("infections") {
infections = deconvolve_infections(
shifted_cases, noise, fixed, backcalc_prior
);
}
}
// convolve from latent infections to mean of observations
if (delay_id) {
vector[delay_type_max[delay_id] + 1] delay_rev_pmf = get_delay_rev_pmf(
delay_id, delay_type_max[delay_id] + 1, delay_types_p, delay_types_id,
delay_types_groups, delay_max, delay_np_pmf,
delay_np_pmf_groups, delay_params, delay_params_groups, delay_dist,
0, 1, 0
);
reports = convolve_to_report(infections, delay_rev_pmf, seeding_time);
vector[delay_type_max[delay_id] + 1] delay_rev_pmf;
profile("delays") {
delay_rev_pmf = get_delay_rev_pmf(
delay_id, delay_type_max[delay_id] + 1, delay_types_p, delay_types_id,
delay_types_groups, delay_max, delay_np_pmf,
delay_np_pmf_groups, delay_params, delay_params_groups, delay_dist,
0, 1, 0
);
}
profile("reports") {
reports = convolve_to_report(infections, delay_rev_pmf, seeding_time);
}
} else {
reports = infections[(seeding_time + 1):t];
}
// weekly reporting effect
if (week_effect > 1) {
reports = day_of_week_effect(reports, day_of_week, day_of_week_simplex);
profile("day of the week") {
reports = day_of_week_effect(reports, day_of_week, day_of_week_simplex);
}
}
// scaling of reported cases by fraction observed
if (obs_scale) {
reports = scale_obs(reports, obs_scale_sd > 0 ? frac_obs[1] : obs_scale_mean);
}
// truncate near time cases to observed reports
if (trunc_id) {
vector[delay_type_max[trunc_id] + 1] trunc_rev_cmf = get_delay_rev_pmf(
trunc_id, delay_type_max[trunc_id] + 1, delay_types_p, delay_types_id,
delay_types_groups, delay_max, delay_np_pmf,
delay_np_pmf_groups, delay_params, delay_params_groups, delay_dist,
0, 1, 1
);
obs_reports = truncate(reports[1:ot], trunc_rev_cmf, 0);
} else {
obs_reports = reports[1:ot];
}
if (obs_scale) {
profile("scale") {
reports = scale_obs(
reports, obs_scale_sd > 0 ? frac_obs[1] : obs_scale_mean
);
}
}
// truncate near time cases to observed reports
if (trunc_id) {
vector[delay_type_max[trunc_id] + 1] trunc_rev_cmf;
profile("truncation") {
trunc_rev_cmf = get_delay_rev_pmf(
trunc_id, delay_type_max[trunc_id] + 1, delay_types_p, delay_types_id,
delay_types_groups, delay_max, delay_np_pmf,
delay_np_pmf_groups, delay_params, delay_params_groups, delay_dist,
0, 1, 1
);
}
profile("truncate") {
obs_reports = truncate(reports[1:ot], trunc_rev_cmf, 0);
}
} else {
obs_reports = reports[1:ot];
}
}

model {
// priors for noise GP
if (!fixed) {
gaussian_process_lp(
rho[1], alpha[1], eta, ls_meanlog, ls_sdlog, ls_min, ls_max, alpha_sd
);
profile("gp lp") {
gaussian_process_lp(
rho[1], alpha[1], eta, ls_meanlog, ls_sdlog, ls_min, ls_max, alpha_sd
);
}
}
// penalised priors for delay distributions
delays_lp(
delay_params, delay_params_mean, delay_params_sd, delay_params_groups,
delay_dist, delay_weight
);
profile("delays lp") {
delays_lp(
delay_params, delay_params_mean, delay_params_sd, delay_params_groups,
delay_dist, delay_weight
);
}
if (estimate_r) {
// priors on Rt
rt_lp(
log_R, initial_infections, initial_growth, bp_effects, bp_sd, bp_n,
seeding_time, r_logmean, r_logsd, prior_infections, prior_growth
);
profile("rt lp") {
rt_lp(
log_R, initial_infections, initial_growth, bp_effects, bp_sd, bp_n,
seeding_time, r_logmean, r_logsd, prior_infections, prior_growth
);
}
}
// prior observation scaling
if (obs_scale_sd > 0) {
frac_obs[1] ~ normal(obs_scale_mean, obs_scale_sd) T[0, 1];
profile("scale lp") {
frac_obs[1] ~ normal(obs_scale_mean, obs_scale_sd) T[0, 1];
}
}
// observed reports from mean of reports (update likelihood)
if (likelihood) {
report_lp(
cases, cases_time, obs_reports, rep_phi, phi_mean, phi_sd, model_type,
obs_weight, accumulate
);
profile("report lp") {
report_lp(
cases, cases_time, obs_reports, rep_phi, phi_mean, phi_sd, model_type,
obs_weight, accumulate
);
}
}
}

Expand All @@ -160,37 +199,39 @@ generated quantities {
real gt_mean;
real gt_var;
vector[return_likelihood ? ot : 0] log_lik;
if (estimate_r){
// estimate growth from estimated Rt
gt_mean = rev_pmf_mean(gt_rev_pmf, 1);
gt_var = rev_pmf_var(gt_rev_pmf, 1, gt_mean);
r = R_to_growth(R, gt_mean, gt_var);
} else {
// sample generation time
vector[delay_params_length] delay_params_sample = to_vector(normal_lb_rng(
delay_params_mean, delay_params_sd, delay_params_lower
));
vector[delay_type_max[gt_id] + 1] sampled_gt_rev_pmf = get_delay_rev_pmf(
gt_id, delay_type_max[gt_id] + 1, delay_types_p, delay_types_id,
delay_types_groups, delay_max, delay_np_pmf,
delay_np_pmf_groups, delay_params_sample, delay_params_groups,
delay_dist, 1, 1, 0
);
gt_mean = rev_pmf_mean(sampled_gt_rev_pmf, 1);
gt_var = rev_pmf_var(sampled_gt_rev_pmf, 1, gt_mean);
// calculate Rt using infections and generation time
gen_R = calculate_Rt(
infections, seeding_time, sampled_gt_rev_pmf, rt_half_window
);
// estimate growth from calculated Rt
r = R_to_growth(gen_R, gt_mean, gt_var);
}
// simulate reported cases
imputed_reports = report_rng(reports, rep_phi, model_type);
// log likelihood of model
if (return_likelihood) {
log_lik = report_log_lik(
cases, obs_reports[cases_time], rep_phi, model_type, obs_weight
);
profile("generated quantities") {
if (estimate_r){
// estimate growth from estimated Rt
gt_mean = rev_pmf_mean(gt_rev_pmf, 1);
gt_var = rev_pmf_var(gt_rev_pmf, 1, gt_mean);
r = R_to_growth(R, gt_mean, gt_var);
} else {
// sample generation time
vector[delay_params_length] delay_params_sample = to_vector(normal_lb_rng(
delay_params_mean, delay_params_sd, delay_params_lower
));
vector[delay_type_max[gt_id] + 1] sampled_gt_rev_pmf = get_delay_rev_pmf(
gt_id, delay_type_max[gt_id] + 1, delay_types_p, delay_types_id,
delay_types_groups, delay_max, delay_np_pmf,
delay_np_pmf_groups, delay_params_sample, delay_params_groups,
delay_dist, 1, 1, 0
);
gt_mean = rev_pmf_mean(sampled_gt_rev_pmf, 1);
gt_var = rev_pmf_var(sampled_gt_rev_pmf, 1, gt_mean);
// calculate Rt using infections and generation time
gen_R = calculate_Rt(
infections, seeding_time, sampled_gt_rev_pmf, rt_half_window
);
// estimate growth from calculated Rt
r = R_to_growth(gen_R, gt_mean, gt_var);
}
// simulate reported cases
imputed_reports = report_rng(reports, rep_phi, model_type);
// log likelihood of model
if (return_likelihood) {
log_lik = report_log_lik(
cases, obs_reports[cases_time], rep_phi, model_type, obs_weight
);
}
}
}

0 comments on commit 985f997

Please sign in to comment.