Skip to content

Commit

Permalink
switch to directly calculated growth rate
Browse files Browse the repository at this point in the history
  • Loading branch information
seabbs authored and sbfnk committed Mar 12, 2024
1 parent 19b5707 commit 0eb758a
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 17 deletions.
2 changes: 1 addition & 1 deletion R/extract.R
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ extract_parameter_samples <- function(stan_fit, data, reported_dates,
out$growth_rate <- extract_parameter(
"r",
samples,
reported_dates
reported_dates[-1]
)
if (data$week_effect > 1) {
out$day_of_week <- extract_parameter(
Expand Down
13 changes: 4 additions & 9 deletions inst/stan/estimate_infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -156,16 +156,11 @@ model {
generated quantities {
array[ot_h] int imputed_reports;
vector[estimate_r > 0 ? 0: ot_h] gen_R;
array[ot_h] real r;
array[ot_h - 1] real r;
real gt_mean;
real gt_var;
vector[return_likelihood ? ot : 0] log_lik;
if (estimate_r){
// estimate growth from estimated Rt
gt_mean = rev_pmf_mean(gt_rev_pmf, 1);
gt_var = rev_pmf_var(gt_rev_pmf, 1, gt_mean);
r = R_to_growth(R, gt_mean, gt_var);
} else {
if (estimate_r == 0){
// sample generation time
vector[delay_params_length] delay_params_sample = to_vector(normal_lb_rng(
delay_params_mean, delay_params_sd, delay_params_lower
Expand All @@ -182,9 +177,9 @@ generated quantities {
gen_R = calculate_Rt(
infections, seeding_time, sampled_gt_rev_pmf, rt_half_window
);
// estimate growth from calculated Rt
r = R_to_growth(gen_R, gt_mean, gt_var);
}
// estimate growth from infections
r = calculate_growth(infections, seeding_time + 1);
// simulate reported cases
imputed_reports = report_rng(reports, rep_phi, model_type);
// log likelihood of model
Expand Down
9 changes: 9 additions & 0 deletions inst/stan/functions/generated_quantities.stan
Original file line number Diff line number Diff line change
Expand Up @@ -45,3 +45,12 @@ array[] real R_to_growth(vector R, real gt_mean, real gt_var) {
}
return(r);
}

// Calculate growth rate
array[] real calculate_growth(vector infections, int seeding_time) {
int t = num_elements(infections);
int ot = t - seeding_time;
vector[t] log_inf = log(infections);
vector[ot] growth = log_inf[(seeding_time + 1):t] - log_inf[seeding_time:(t - 1)];
return(to_array_1d(growth));
}
8 changes: 2 additions & 6 deletions inst/stan/simulate_infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ generated quantities {
matrix[n, t] infections; //latent infections
matrix[n, t - seeding_time] reports; // observed cases
array[n, t - seeding_time] int imputed_reports;
array[n, t - seeding_time] real r;
array[n, t - seeding_time - 1] real r;
for (i in 1:n) {
// generate infections from Rt trace
vector[delay_type_max[gt_id] + 1] gt_rev_pmf;
Expand Down Expand Up @@ -94,10 +94,6 @@ generated quantities {
imputed_reports[i] = report_rng(
to_vector(reports[i]), rep_phi[i], model_type
);
{
real gt_mean = rev_pmf_mean(gt_rev_pmf, 0);
real gt_var = rev_pmf_var(gt_rev_pmf, 0, gt_mean);
r[i] = R_to_growth(to_vector(R[i]), gt_mean, gt_var);
}
r[i] = calculate_growth(to_vector(infections[i]), seeding_time + 1);
}
}
2 changes: 1 addition & 1 deletion tests/testthat/setup.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ library("lifecycle")
if (identical(Sys.getenv("NOT_CRAN"), "true")) {
files <- c(
"convolve.stan", "pmfs.stan", "observation_model.stan", "secondary.stan",
"rt.stan", "infections.stan", "delays.stan"
"rt.stan", "infections.stan", "delays.stan", "generated_quantities.stan"
)
if (!(tolower(Sys.info()[["sysname"]]) %in% "windows")) {
suppressMessages(
Expand Down
8 changes: 8 additions & 0 deletions tests/testthat/test-stan-generated_quantities.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
test_that("calculate_growth works as expected", {
skip_on_cran()
expect_equal(calculate_growth(rep(1, 5), 1), rep(0, 4))
expect_equal(round(calculate_growth(1:5, 2), 2), c(0.41, 0.29, 0.22))
expect_equal(round(calculate_growth(exp(0.4*1:5), 2), 2), rep(0.4, 3))
expect_error(calculate_growth(1:5, 6))
expect_error(calculate_growth(1:5, 0))
})

0 comments on commit 0eb758a

Please sign in to comment.