From 93c072d9aad839ccda76fa75f74ebc9d403f4d26 Mon Sep 17 00:00:00 2001 From: Sebastian Funk Date: Fri, 22 Mar 2024 20:51:19 +0000 Subject: [PATCH] Directly calculated growth (#610) * switch to directly calculated growth rate * remove superseded function * add news item * avoid conversion * add skips * add test for exp(r)==R with fixed gt of 1 day * fix typo and add reviewer Co-authored-by: Sam Abbott --------- Co-authored-by: seabbs --- NEWS.md | 1 + R/extract.R | 2 +- inst/stan/estimate_infections.stan | 13 ++++------ inst/stan/functions/generated_quantities.stan | 24 +++++++------------ inst/stan/simulate_infections.stan | 10 ++++---- tests/testthat/setup.R | 2 +- tests/testthat/test-estimate_infections.R | 22 +++++++++++++++-- .../testthat/test-stan-generated_quantities.R | 11 +++++++++ 8 files changed, 50 insertions(+), 35 deletions(-) create mode 100644 tests/testthat/test-stan-generated_quantities.R diff --git a/NEWS.md b/NEWS.md index bb243bb5e..b3246c7a6 100644 --- a/NEWS.md +++ b/NEWS.md @@ -40,6 +40,7 @@ * Updated the parameterisation of the dispersion term `phi` to be `phi = 1 / sqrt_phi ^ 2` rather than the previous parameterisation `phi = 1 / sqrt(sqrt_phi)` based on the suggested prior [here](https://github.com/stan-dev/stan/wiki/Prior-Choice-Recommendations#story-when-the-generic-prior-fails-the-case-of-the-negative-binomial) and the performance benefits seen in the `epinowcast` package (see [here](https://github.com/epinowcast/epinowcast/blob/8eff560d1fd8305f5fb26c21324b2bfca1f002b4/inst/stan/epinowcast.stan#L314)). By @seabbs in #487 and reviewed by @sbfnk. * Added an `na` argument to `obs_opts()` that allows the user to specify whether NA values in the data should be interpreted as missing or accumulated in the next non-NA data point. By @sbfnk in #534 and reviewed by @seabbs. +* Growth rates are now calculated directly from the infection trajectory as `log I(t) - log I(t - 1)`. Originally by @seabbs in #213, finished by @sbfnk in #610 and reviewed by @seabbs. # EpiNow2 1.4.0 diff --git a/R/extract.R b/R/extract.R index 9b1f011e6..838003bce 100644 --- a/R/extract.R +++ b/R/extract.R @@ -207,7 +207,7 @@ extract_parameter_samples <- function(stan_fit, data, reported_dates, out$growth_rate <- extract_parameter( "r", samples, - reported_dates + reported_dates[-1] ) if (data$week_effect > 1) { out$day_of_week <- extract_parameter( diff --git a/inst/stan/estimate_infections.stan b/inst/stan/estimate_infections.stan index 5da1ab3a8..81a6538d1 100644 --- a/inst/stan/estimate_infections.stan +++ b/inst/stan/estimate_infections.stan @@ -195,17 +195,12 @@ model { generated quantities { array[ot_h] int imputed_reports; vector[estimate_r > 0 ? 0: ot_h] gen_R; - array[ot_h] real r; + vector[ot_h - 1] r; real gt_mean; real gt_var; vector[return_likelihood ? ot : 0] log_lik; profile("generated quantities") { - if (estimate_r){ - // estimate growth from estimated Rt - gt_mean = rev_pmf_mean(gt_rev_pmf, 1); - gt_var = rev_pmf_var(gt_rev_pmf, 1, gt_mean); - r = R_to_growth(R, gt_mean, gt_var); - } else { + if (estimate_r == 0){ // sample generation time vector[delay_params_length] delay_params_sample = to_vector(normal_lb_rng( delay_params_mean, delay_params_sd, delay_params_lower @@ -222,9 +217,9 @@ generated quantities { gen_R = calculate_Rt( infections, seeding_time, sampled_gt_rev_pmf, rt_half_window ); - // estimate growth from calculated Rt - r = R_to_growth(gen_R, gt_mean, gt_var); } + // estimate growth from infections + r = calculate_growth(infections, seeding_time + 1); // simulate reported cases imputed_reports = report_rng(reports, rep_phi, model_type); // log likelihood of model diff --git a/inst/stan/functions/generated_quantities.stan b/inst/stan/functions/generated_quantities.stan index 39411a999..d418a9d7b 100644 --- a/inst/stan/functions/generated_quantities.stan +++ b/inst/stan/functions/generated_quantities.stan @@ -28,20 +28,12 @@ vector calculate_Rt(vector infections, int seeding_time, } return(sR); } -// Convert an estimate of Rt to growth -array[] real R_to_growth(vector R, real gt_mean, real gt_var) { - int t = num_elements(R); - array[t] real r; - if (gt_var > 0) { - real k = gt_var * inv_square(gt_mean); - for (s in 1:t) { - r[s] = (pow(R[s], k) - 1) / (k * gt_mean); - } - } else { - // limit as gt_sd -> 0 - for (s in 1:t) { - r[s] = log(R[s]) / gt_mean; - } - } - return(r); + +// Calculate growth rate +vector calculate_growth(vector infections, int seeding_time) { + int t = num_elements(infections); + int ot = t - seeding_time; + vector[t] log_inf = log(infections); + vector[ot] growth = log_inf[(seeding_time + 1):t] - log_inf[seeding_time:(t - 1)]; + return(growth); } diff --git a/inst/stan/simulate_infections.stan b/inst/stan/simulate_infections.stan index d6ff12128..1f4f65cb9 100644 --- a/inst/stan/simulate_infections.stan +++ b/inst/stan/simulate_infections.stan @@ -35,7 +35,7 @@ generated quantities { matrix[n, t] infections; //latent infections matrix[n, t - seeding_time] reports; // observed cases array[n, t - seeding_time] int imputed_reports; - array[n, t - seeding_time] real r; + matrix[n, t - seeding_time - 1] r; for (i in 1:n) { // generate infections from Rt trace vector[delay_type_max[gt_id] + 1] gt_rev_pmf; @@ -94,10 +94,8 @@ generated quantities { imputed_reports[i] = report_rng( to_vector(reports[i]), rep_phi[i], model_type ); - { - real gt_mean = rev_pmf_mean(gt_rev_pmf, 0); - real gt_var = rev_pmf_var(gt_rev_pmf, 0, gt_mean); - r[i] = R_to_growth(to_vector(R[i]), gt_mean, gt_var); - } + r[i] = to_row_vector( + calculate_growth(to_vector(infections[i]), seeding_time + 1) + ); } } diff --git a/tests/testthat/setup.R b/tests/testthat/setup.R index 4b63e43af..a5d565846 100644 --- a/tests/testthat/setup.R +++ b/tests/testthat/setup.R @@ -4,7 +4,7 @@ library("lifecycle") if (identical(Sys.getenv("NOT_CRAN"), "true")) { files <- c( "convolve.stan", "pmfs.stan", "observation_model.stan", "secondary.stan", - "rt.stan", "infections.stan", "delays.stan" + "rt.stan", "infections.stan", "delays.stan", "generated_quantities.stan" ) if (!(tolower(Sys.info()[["sysname"]]) %in% "windows")) { suppressMessages( diff --git a/tests/testthat/test-estimate_infections.R b/tests/testthat/test-estimate_infections.R index 53234b0ff..64dafd376 100644 --- a/tests/testthat/test-estimate_infections.R +++ b/tests/testthat/test-estimate_infections.R @@ -4,7 +4,8 @@ futile.logger::flog.threshold("FATAL") reported_cases <- EpiNow2::example_confirmed[1:30] -default_estimate_infections <- function(..., add_stan = list(), delay = TRUE) { +default_estimate_infections <- function(..., add_stan = list(), gt = TRUE, + delay = TRUE) { futile.logger::flog.threshold("FATAL") def_stan <- list( @@ -15,7 +16,9 @@ default_estimate_infections <- function(..., add_stan = list(), delay = TRUE) { stan_args <- do.call(stan_opts, def_stan) suppressWarnings(estimate_infections(..., - generation_time = generation_time_opts(example_generation_time), + generation_time = fifelse( + gt, generation_time_opts(example_generation_time), generation_time_opts() + ), delays = ifelse(delay, list(delay_opts(example_reporting_delay)), list(delay_opts()))[[1]], stan = stan_args, verbose = FALSE )) @@ -27,6 +30,7 @@ test_estimate_infections <- function(...) { expect_true(nrow(out$samples) > 0) expect_true(nrow(out$summarised) > 0) expect_true(nrow(out$observations) > 0) + invisible(out) } # Test functionality ------------------------------------------------------ @@ -89,6 +93,20 @@ test_that("estimate_infections successfully returns estimates using a random wal test_estimate_infections(reported_cases, gp = NULL, rt = rt_opts(rw = 7)) }) +test_that("estimate_infections works without setting a generation time", { + skip_on_cran() + df <- test_estimate_infections(reported_cases, gt = FALSE, delay = FALSE) + ## check exp(r) == R + growth_rate <- df$samples[variable == "growth_rate"][, + list(date, sample, growth_rate = value) + ] + R <- df$samples[variable == "R"][, + list(date, sample, R = value) + ] + combined <- merge(growth_rate, R, by = c("date", "sample"), all = FALSE) + expect_equal(exp(combined$growth_rate), combined$R) +}) + test_that("estimate_infections fails as expected when given a very short timeout", { skip_on_cran() expect_error(output <- capture.output(suppressMessages( diff --git a/tests/testthat/test-stan-generated_quantities.R b/tests/testthat/test-stan-generated_quantities.R new file mode 100644 index 000000000..60e2285da --- /dev/null +++ b/tests/testthat/test-stan-generated_quantities.R @@ -0,0 +1,11 @@ +skip_on_cran() +skip_on_os("windows") + +test_that("calculate_growth works as expected", { + skip_on_cran() + expect_equal(calculate_growth(rep(1, 5), 1), rep(0, 4)) + expect_equal(round(calculate_growth(1:5, 2), 2), c(0.41, 0.29, 0.22)) + expect_equal(round(calculate_growth(exp(0.4*1:5), 2), 2), rep(0.4, 3)) + expect_error(calculate_growth(1:5, 6)) + expect_error(calculate_growth(1:5, 0)) +})