Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Directly calculated growth #610

Merged
merged 7 commits into from
Mar 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@

* Updated the parameterisation of the dispersion term `phi` to be `phi = 1 / sqrt_phi ^ 2` rather than the previous parameterisation `phi = 1 / sqrt(sqrt_phi)` based on the suggested prior [here](https://github.com/stan-dev/stan/wiki/Prior-Choice-Recommendations#story-when-the-generic-prior-fails-the-case-of-the-negative-binomial) and the performance benefits seen in the `epinowcast` package (see [here](https://github.com/epinowcast/epinowcast/blob/8eff560d1fd8305f5fb26c21324b2bfca1f002b4/inst/stan/epinowcast.stan#L314)). By @seabbs in #487 and reviewed by @sbfnk.
* Added an `na` argument to `obs_opts()` that allows the user to specify whether NA values in the data should be interpreted as missing or accumulated in the next non-NA data point. By @sbfnk in #534 and reviewed by @seabbs.
* Growth rates are now calculated directly from the infection trajectory as `log I(t) - log I(t - 1)`. Originally by @seabbs in #213, finished by @sbfnk in #610 and reviewed by @seabbs.

# EpiNow2 1.4.0

Expand Down
2 changes: 1 addition & 1 deletion R/extract.R
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ extract_parameter_samples <- function(stan_fit, data, reported_dates,
out$growth_rate <- extract_parameter(
"r",
samples,
reported_dates
reported_dates[-1]
)
if (data$week_effect > 1) {
out$day_of_week <- extract_parameter(
Expand Down
13 changes: 4 additions & 9 deletions inst/stan/estimate_infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -195,17 +195,12 @@ model {
generated quantities {
array[ot_h] int imputed_reports;
vector[estimate_r > 0 ? 0: ot_h] gen_R;
array[ot_h] real r;
vector[ot_h - 1] r;
real gt_mean;
real gt_var;
vector[return_likelihood ? ot : 0] log_lik;
profile("generated quantities") {
if (estimate_r){
// estimate growth from estimated Rt
gt_mean = rev_pmf_mean(gt_rev_pmf, 1);
gt_var = rev_pmf_var(gt_rev_pmf, 1, gt_mean);
r = R_to_growth(R, gt_mean, gt_var);
} else {
if (estimate_r == 0){
// sample generation time
vector[delay_params_length] delay_params_sample = to_vector(normal_lb_rng(
delay_params_mean, delay_params_sd, delay_params_lower
Expand All @@ -222,9 +217,9 @@ generated quantities {
gen_R = calculate_Rt(
infections, seeding_time, sampled_gt_rev_pmf, rt_half_window
);
// estimate growth from calculated Rt
r = R_to_growth(gen_R, gt_mean, gt_var);
}
// estimate growth from infections
r = calculate_growth(infections, seeding_time + 1);
// simulate reported cases
imputed_reports = report_rng(reports, rep_phi, model_type);
// log likelihood of model
Expand Down
24 changes: 8 additions & 16 deletions inst/stan/functions/generated_quantities.stan
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,12 @@ vector calculate_Rt(vector infections, int seeding_time,
}
return(sR);
}
// Convert an estimate of Rt to growth
array[] real R_to_growth(vector R, real gt_mean, real gt_var) {
int t = num_elements(R);
array[t] real r;
if (gt_var > 0) {
real k = gt_var * inv_square(gt_mean);
for (s in 1:t) {
r[s] = (pow(R[s], k) - 1) / (k * gt_mean);
}
} else {
// limit as gt_sd -> 0
for (s in 1:t) {
r[s] = log(R[s]) / gt_mean;
}
}
return(r);

// Calculate growth rate
vector calculate_growth(vector infections, int seeding_time) {
int t = num_elements(infections);
int ot = t - seeding_time;
vector[t] log_inf = log(infections);
vector[ot] growth = log_inf[(seeding_time + 1):t] - log_inf[seeding_time:(t - 1)];
return(growth);
}
10 changes: 4 additions & 6 deletions inst/stan/simulate_infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ generated quantities {
matrix[n, t] infections; //latent infections
matrix[n, t - seeding_time] reports; // observed cases
array[n, t - seeding_time] int imputed_reports;
array[n, t - seeding_time] real r;
matrix[n, t - seeding_time - 1] r;
for (i in 1:n) {
// generate infections from Rt trace
vector[delay_type_max[gt_id] + 1] gt_rev_pmf;
Expand Down Expand Up @@ -94,10 +94,8 @@ generated quantities {
imputed_reports[i] = report_rng(
to_vector(reports[i]), rep_phi[i], model_type
);
{
real gt_mean = rev_pmf_mean(gt_rev_pmf, 0);
real gt_var = rev_pmf_var(gt_rev_pmf, 0, gt_mean);
r[i] = R_to_growth(to_vector(R[i]), gt_mean, gt_var);
}
r[i] = to_row_vector(
calculate_growth(to_vector(infections[i]), seeding_time + 1)
);
}
}
2 changes: 1 addition & 1 deletion tests/testthat/setup.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ library("lifecycle")
if (identical(Sys.getenv("NOT_CRAN"), "true")) {
files <- c(
"convolve.stan", "pmfs.stan", "observation_model.stan", "secondary.stan",
"rt.stan", "infections.stan", "delays.stan"
"rt.stan", "infections.stan", "delays.stan", "generated_quantities.stan"
)
if (!(tolower(Sys.info()[["sysname"]]) %in% "windows")) {
suppressMessages(
Expand Down
22 changes: 20 additions & 2 deletions tests/testthat/test-estimate_infections.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@ futile.logger::flog.threshold("FATAL")

reported_cases <- EpiNow2::example_confirmed[1:30]

default_estimate_infections <- function(..., add_stan = list(), delay = TRUE) {
default_estimate_infections <- function(..., add_stan = list(), gt = TRUE,
delay = TRUE) {
futile.logger::flog.threshold("FATAL")

def_stan <- list(
Expand All @@ -15,7 +16,9 @@ default_estimate_infections <- function(..., add_stan = list(), delay = TRUE) {
stan_args <- do.call(stan_opts, def_stan)

suppressWarnings(estimate_infections(...,
generation_time = generation_time_opts(example_generation_time),
generation_time = fifelse(
gt, generation_time_opts(example_generation_time), generation_time_opts()
),
delays = ifelse(delay, list(delay_opts(example_reporting_delay)), list(delay_opts()))[[1]],
stan = stan_args, verbose = FALSE
))
Expand All @@ -27,6 +30,7 @@ test_estimate_infections <- function(...) {
expect_true(nrow(out$samples) > 0)
expect_true(nrow(out$summarised) > 0)
expect_true(nrow(out$observations) > 0)
invisible(out)
}

# Test functionality ------------------------------------------------------
Expand Down Expand Up @@ -89,6 +93,20 @@ test_that("estimate_infections successfully returns estimates using a random wal
test_estimate_infections(reported_cases, gp = NULL, rt = rt_opts(rw = 7))
})

test_that("estimate_infections works without setting a generation time", {
skip_on_cran()
df <- test_estimate_infections(reported_cases, gt = FALSE, delay = FALSE)
## check exp(r) == R
growth_rate <- df$samples[variable == "growth_rate"][,
list(date, sample, growth_rate = value)
]
R <- df$samples[variable == "R"][,
list(date, sample, R = value)
]
combined <- merge(growth_rate, R, by = c("date", "sample"), all = FALSE)
expect_equal(exp(combined$growth_rate), combined$R)
})

test_that("estimate_infections fails as expected when given a very short timeout", {
skip_on_cran()
expect_error(output <- capture.output(suppressMessages(
Expand Down
11 changes: 11 additions & 0 deletions tests/testthat/test-stan-generated_quantities.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
skip_on_cran()
skip_on_os("windows")

test_that("calculate_growth works as expected", {
skip_on_cran()
expect_equal(calculate_growth(rep(1, 5), 1), rep(0, 4))
expect_equal(round(calculate_growth(1:5, 2), 2), c(0.41, 0.29, 0.22))
expect_equal(round(calculate_growth(exp(0.4*1:5), 2), 2), rep(0.4, 3))
expect_error(calculate_growth(1:5, 6))
expect_error(calculate_growth(1:5, 0))
})
Loading