Skip to content

Commit

Permalink
Directly calculated growth (#610)
Browse files Browse the repository at this point in the history
* switch to directly calculated growth rate

* remove superseded function

* add news item

* avoid conversion

* add skips

* add test for exp(r)==R with fixed gt of 1 day

* fix typo and add reviewer

Co-authored-by: Sam Abbott <[email protected]>

---------

Co-authored-by: seabbs <[email protected]>
  • Loading branch information
sbfnk and seabbs authored Mar 22, 2024
1 parent 4856f26 commit 93c072d
Show file tree
Hide file tree
Showing 8 changed files with 50 additions and 35 deletions.
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@

* Updated the parameterisation of the dispersion term `phi` to be `phi = 1 / sqrt_phi ^ 2` rather than the previous parameterisation `phi = 1 / sqrt(sqrt_phi)` based on the suggested prior [here](https://github.com/stan-dev/stan/wiki/Prior-Choice-Recommendations#story-when-the-generic-prior-fails-the-case-of-the-negative-binomial) and the performance benefits seen in the `epinowcast` package (see [here](https://github.com/epinowcast/epinowcast/blob/8eff560d1fd8305f5fb26c21324b2bfca1f002b4/inst/stan/epinowcast.stan#L314)). By @seabbs in #487 and reviewed by @sbfnk.
* Added an `na` argument to `obs_opts()` that allows the user to specify whether NA values in the data should be interpreted as missing or accumulated in the next non-NA data point. By @sbfnk in #534 and reviewed by @seabbs.
* Growth rates are now calculated directly from the infection trajectory as `log I(t) - log I(t - 1)`. Originally by @seabbs in #213, finished by @sbfnk in #610 and reviewed by @seabbs.

# EpiNow2 1.4.0

Expand Down
2 changes: 1 addition & 1 deletion R/extract.R
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ extract_parameter_samples <- function(stan_fit, data, reported_dates,
out$growth_rate <- extract_parameter(
"r",
samples,
reported_dates
reported_dates[-1]
)
if (data$week_effect > 1) {
out$day_of_week <- extract_parameter(
Expand Down
13 changes: 4 additions & 9 deletions inst/stan/estimate_infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -195,17 +195,12 @@ model {
generated quantities {
array[ot_h] int imputed_reports;
vector[estimate_r > 0 ? 0: ot_h] gen_R;
array[ot_h] real r;
vector[ot_h - 1] r;
real gt_mean;
real gt_var;
vector[return_likelihood ? ot : 0] log_lik;
profile("generated quantities") {
if (estimate_r){
// estimate growth from estimated Rt
gt_mean = rev_pmf_mean(gt_rev_pmf, 1);
gt_var = rev_pmf_var(gt_rev_pmf, 1, gt_mean);
r = R_to_growth(R, gt_mean, gt_var);
} else {
if (estimate_r == 0){
// sample generation time
vector[delay_params_length] delay_params_sample = to_vector(normal_lb_rng(
delay_params_mean, delay_params_sd, delay_params_lower
Expand All @@ -222,9 +217,9 @@ generated quantities {
gen_R = calculate_Rt(
infections, seeding_time, sampled_gt_rev_pmf, rt_half_window
);
// estimate growth from calculated Rt
r = R_to_growth(gen_R, gt_mean, gt_var);
}
// estimate growth from infections
r = calculate_growth(infections, seeding_time + 1);
// simulate reported cases
imputed_reports = report_rng(reports, rep_phi, model_type);
// log likelihood of model
Expand Down
24 changes: 8 additions & 16 deletions inst/stan/functions/generated_quantities.stan
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,12 @@ vector calculate_Rt(vector infections, int seeding_time,
}
return(sR);
}
// Convert an estimate of Rt to growth
array[] real R_to_growth(vector R, real gt_mean, real gt_var) {
int t = num_elements(R);
array[t] real r;
if (gt_var > 0) {
real k = gt_var * inv_square(gt_mean);
for (s in 1:t) {
r[s] = (pow(R[s], k) - 1) / (k * gt_mean);
}
} else {
// limit as gt_sd -> 0
for (s in 1:t) {
r[s] = log(R[s]) / gt_mean;
}
}
return(r);

// Calculate growth rate
vector calculate_growth(vector infections, int seeding_time) {
int t = num_elements(infections);
int ot = t - seeding_time;
vector[t] log_inf = log(infections);
vector[ot] growth = log_inf[(seeding_time + 1):t] - log_inf[seeding_time:(t - 1)];
return(growth);
}
10 changes: 4 additions & 6 deletions inst/stan/simulate_infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ generated quantities {
matrix[n, t] infections; //latent infections
matrix[n, t - seeding_time] reports; // observed cases
array[n, t - seeding_time] int imputed_reports;
array[n, t - seeding_time] real r;
matrix[n, t - seeding_time - 1] r;
for (i in 1:n) {
// generate infections from Rt trace
vector[delay_type_max[gt_id] + 1] gt_rev_pmf;
Expand Down Expand Up @@ -94,10 +94,8 @@ generated quantities {
imputed_reports[i] = report_rng(
to_vector(reports[i]), rep_phi[i], model_type
);
{
real gt_mean = rev_pmf_mean(gt_rev_pmf, 0);
real gt_var = rev_pmf_var(gt_rev_pmf, 0, gt_mean);
r[i] = R_to_growth(to_vector(R[i]), gt_mean, gt_var);
}
r[i] = to_row_vector(
calculate_growth(to_vector(infections[i]), seeding_time + 1)
);
}
}
2 changes: 1 addition & 1 deletion tests/testthat/setup.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ library("lifecycle")
if (identical(Sys.getenv("NOT_CRAN"), "true")) {
files <- c(
"convolve.stan", "pmfs.stan", "observation_model.stan", "secondary.stan",
"rt.stan", "infections.stan", "delays.stan"
"rt.stan", "infections.stan", "delays.stan", "generated_quantities.stan"
)
if (!(tolower(Sys.info()[["sysname"]]) %in% "windows")) {
suppressMessages(
Expand Down
22 changes: 20 additions & 2 deletions tests/testthat/test-estimate_infections.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ futile.logger::flog.threshold("FATAL")

reported_cases <- EpiNow2::example_confirmed[1:30]

default_estimate_infections <- function(..., add_stan = list(), delay = TRUE) {
default_estimate_infections <- function(..., add_stan = list(), gt = TRUE,
delay = TRUE) {
futile.logger::flog.threshold("FATAL")

def_stan <- list(
Expand All @@ -15,7 +16,9 @@ default_estimate_infections <- function(..., add_stan = list(), delay = TRUE) {
stan_args <- do.call(stan_opts, def_stan)

suppressWarnings(estimate_infections(...,
generation_time = generation_time_opts(example_generation_time),
generation_time = fifelse(
gt, generation_time_opts(example_generation_time), generation_time_opts()
),
delays = ifelse(delay, list(delay_opts(example_reporting_delay)), list(delay_opts()))[[1]],
stan = stan_args, verbose = FALSE
))
Expand All @@ -27,6 +30,7 @@ test_estimate_infections <- function(...) {
expect_true(nrow(out$samples) > 0)
expect_true(nrow(out$summarised) > 0)
expect_true(nrow(out$observations) > 0)
invisible(out)
}

# Test functionality ------------------------------------------------------
Expand Down Expand Up @@ -89,6 +93,20 @@ test_that("estimate_infections successfully returns estimates using a random wal
test_estimate_infections(reported_cases, gp = NULL, rt = rt_opts(rw = 7))
})

test_that("estimate_infections works without setting a generation time", {
skip_on_cran()
df <- test_estimate_infections(reported_cases, gt = FALSE, delay = FALSE)
## check exp(r) == R
growth_rate <- df$samples[variable == "growth_rate"][,
list(date, sample, growth_rate = value)
]
R <- df$samples[variable == "R"][,
list(date, sample, R = value)
]
combined <- merge(growth_rate, R, by = c("date", "sample"), all = FALSE)
expect_equal(exp(combined$growth_rate), combined$R)
})

test_that("estimate_infections fails as expected when given a very short timeout", {
skip_on_cran()
expect_error(output <- capture.output(suppressMessages(
Expand Down
11 changes: 11 additions & 0 deletions tests/testthat/test-stan-generated_quantities.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
skip_on_cran()
skip_on_os("windows")

test_that("calculate_growth works as expected", {
skip_on_cran()
expect_equal(calculate_growth(rep(1, 5), 1), rep(0, 4))
expect_equal(round(calculate_growth(1:5, 2), 2), c(0.41, 0.29, 0.22))
expect_equal(round(calculate_growth(exp(0.4*1:5), 2), 2), rep(0.4, 3))
expect_error(calculate_growth(1:5, 6))
expect_error(calculate_growth(1:5, 0))
})

0 comments on commit 93c072d

Please sign in to comment.