diff --git a/NAMESPACE b/NAMESPACE index 358611087..633306f14 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -117,6 +117,7 @@ importFrom(checkmate,assert_names) importFrom(checkmate,assert_numeric) importFrom(checkmate,assert_path_for_output) importFrom(checkmate,assert_string) +importFrom(checkmate,assert_subset) importFrom(checkmate,test_data_frame) importFrom(checkmate,test_numeric) importFrom(data.table,":=") diff --git a/NEWS.md b/NEWS.md index bdbcd6b6d..59c93fc08 100644 --- a/NEWS.md +++ b/NEWS.md @@ -7,6 +7,7 @@ * The `fixed` argument to `dist_spec` has been deprecated and replaced by a `fix_dist()` function. By @sbfnk in #503 and reviewed by @seabbs. * Updated `estimate_infections()` so that rather than imputing missing data, it now skips these data points in the likelihood. This is a breaking change as it alters the behaviour of the model when dates are missing from a time series but are known to be zero. We recommend that users check their results when updating to this version but expect this to in most cases improve performance. By @seabbs in #528 and reviewed by @sbfnk. * `simulate_infections` has been renamed to `forecast_infections` in line with `simulate_secondary` and `forecast_secondary`. The terminology is: a forecast is done from a fit to existing data, a simulation from first principles. By @sbfnk in #544 and reviewed by @seabbs. +* A new `simulate_infections` function has been added that can be used to simulate from the model from given initial conditions and parameters. By @sbfnk in #557 and reviewed by @jamesmbaazam. 
## Documentation diff --git a/R/extract.R b/R/extract.R index 8d4cb1232..ae2e18e0f 100644 --- a/R/extract.R +++ b/R/extract.R @@ -181,28 +181,31 @@ extract_parameter_samples <- function(stan_fit, data, reported_dates, samples, reported_dates ) - if (data$estimate_r == 1) { - out$R <- extract_parameter( - "R", - samples, - reported_dates - ) - if (data$bp_n > 0) { - out$breakpoints <- extract_parameter( - "bp_effects", + if ("estimate_r" %in% names(data)) { + if (data$estimate_r == 1) { + out$R <- extract_parameter( + "R", samples, - 1:data$bp_n + reported_dates + ) + if (data$bp_n > 0) { + out$breakpoints <- extract_parameter( + "bp_effects", + samples, + 1:data$bp_n + ) + out$breakpoints <- out$breakpoints[ + , + strat := date + ][, c("time", "date") := NULL] + } + } else { + out$R <- extract_parameter( + "gen_R", + samples, + reported_dates ) - out$breakpoints <- out$breakpoints[, - strat := date][, c("time", "date") := NULL - ] } - } else { - out$R <- extract_parameter( - "gen_R", - samples, - reported_dates - ) } out$growth_rate <- extract_parameter( "r", @@ -243,7 +246,7 @@ extract_parameter_samples <- function(stan_fit, data, reported_dates, value.V1 := NULL ] } - if (data$obs_scale_sd > 0) { + if ("obs_scale_sd" %in% names(data) && data$obs_scale_sd > 0) { out$fraction_observed <- extract_static_parameter("frac_obs", samples) out$fraction_observed <- out$fraction_observed[, value := value.V1][, value.V1 := NULL diff --git a/R/simulate_infections.R b/R/simulate_infections.R index 6e98de85b..61d474a47 100644 --- a/R/simulate_infections.R +++ b/R/simulate_infections.R @@ -1,19 +1,208 @@ -#' Deprecated; use [forecast_infections()] instead +#' Simulate infections using the renewal equation #' -#' Calling this function passes all arguments to [forecast_infections()] -#' @description `r lifecycle::badge("deprecated")` -#' @param ... 
Arguments to be passed to [forecast_infections()] -#' @return the result of [forecast_infections()] +#' Simulations are done from given initial infections and, potentially +#' time-varying, reproduction numbers. Delays and parameters of the observation +#' model can be specified using the same options as in [estimate_infections()]. +#' +#' In order to simulate, all parameters that are specified, such as the mean and +#' standard deviation of delays or observation scaling, must be fixed. +#' Uncertain parameters are not allowed. +#' +#' A previous function called [simulate_infections()] that simulates from a +#' given model fit has been renamed [forecast_infections()]. Using +#' [simulate_infections()] with existing estimates is now deprecated. This +#' option will be removed in version 2.1.0. +#' @param R a data frame of reproduction numbers (column `R`) by date (column +#' `date`). Column `R` must be numeric and `date` must be in date format. If +#' not all days between the first and last day in the `date` column are +#' present, it will be assumed that R stays the same until the next given date. +#' @param initial_infections numeric; the initial number of infections. +#' @param day_of_week_effect either `NULL` (no day of the week effect) or a +#' numerical vector of length specified in [obs_opts()] as `week_length` +#' (default: 7) if `week_effect` is set to TRUE. Each element of the vector +#' gives the weight given to reporting on this day (normalised to 1). +#' The default is `NULL`. +#' @param estimates deprecated; use [forecast_infections()] instead +#' @param ... 
deprecated; only included for backward compatibility +#' @inheritParams estimate_infections +#' @inheritParams rt_opts +#' @inheritParams stan_opts +#' @importFrom lifecycle deprecate_warn +#' @importFrom checkmate assert_data_frame assert_date assert_numeric +#' assert_subset +#' @importFrom data.table data.table merge.data.table nafill rbindlist +#' @return A data.table of simulated infections (variable `infections`) and +#' reported cases (variable `reported_cases`) by date. +#' @author Sebastian Funk #' @export -simulate_infections <- function(...) { +#' @examples +#' \donttest{ +#' R <- data.frame( +#' date = seq.Date(as.Date("2023-01-01"), length.out = 14, by = "day"), +#' R = c(rep(1.2, 7), rep(0.8, 7)) +#' ) +#' sim <- simulate_infections( +#' R = R, +#' initial_infections = 100, +#' generation_time = generation_time_opts( +#' fix_dist(example_generation_time) +#' ), +#' delays = delay_opts(fix_dist(example_reporting_delay)), +#' obs = obs_opts(family = "poisson") +#' ) +#' } +simulate_infections <- function(estimates, R, initial_infections, + day_of_week_effect = NULL, + generation_time = generation_time_opts(), + delays = delay_opts(), + truncation = trunc_opts(), + obs = obs_opts(), + CrIs = c(0.2, 0.5, 0.9), + backend = "rstan", + pop = 0, ...) { + ## deprecated usage + if (!missing(estimates)) { deprecate_warn( "2.0.0", - "simulate_infections()", + "simulate_infections(estimates)", "forecast_infections()", - "A new [simulate_infections()] function for simulating from given ", - "parameters is planned for implementation in the future." + details = paste0( + "This `estimates` option will be removed from [simulate_infections()] ", + "in version 2.1.0." 
+ ) + ) + return(forecast_infections(estimates = estimates, ...)) + } + + ## check inputs + assert_data_frame(R, any.missing = FALSE) + assert_subset(colnames(R), c("date", "R")) + assert_date(R$date) + assert_numeric(R$R, lower = 0) + assert_numeric(initial_infections, lower = 0) + assert_numeric(day_of_week_effect, lower = 0, null.ok = TRUE) + assert_numeric(pop, lower = 0) + assert_class(delays, "delay_opts") + assert_class(obs, "obs_opts") + assert_class(generation_time, "generation_time_opts") + + ## create R for all dates modelled + all_dates <- data.table(date = seq.Date(min(R$date), max(R$date), by = "day")) + R <- merge.data.table(all_dates, R, by = "date", all.x = TRUE) + R <- R[, R := nafill(R, type = "locf")] + ## remove any initial NAs + R <- R[!is.na(R)] + + seeding_time <- get_seeding_time(delays, generation_time) + if (seeding_time > 1) { + ## estimate initial growth from initial reproduction number if seeding time + ## is greater than 1 + initial_growth <- (R$R[1] - 1) / mean(generation_time) + } else { + initial_growth <- numeric(0) + } + + data <- list( + n = 1, + t = nrow(R) + seeding_time, + seeding_time = seeding_time, + future_time = 0, + initial_infections = array(log(initial_infections), dim = c(1, 1)), + initial_growth = array(initial_growth, dim = c(1, length(initial_growth))), + R = array(R$R, dim = c(1, nrow(R))), + pop = pop + ) + + data <- c(data, create_stan_delays( + gt = generation_time, + delay = delays, + trunc = truncation + )) + + if ((length(data$delay_mean_sd) > 0 && any(data$delay_mean_sd > 0)) || + (length(data$delay_sd_sd) > 0 && any(data$delay_sd_sd > 0))) { + stop( + "Cannot simulate from uncertain parameters. Use the [fix_dist()] ", + "function to set the parameters of uncertain distributions either the ", + "mean or a randomly sampled value" ) - forecast_infections(...) 
+ } + data$delay_mean <- array( + data$delay_mean_mean, dim = c(1, length(data$delay_mean_mean)) + ) + data$delay_sd <- array( + data$delay_sd_mean, dim = c(1, length(data$delay_sd_mean)) + ) + data$delay_mean_sd <- NULL + data$delay_sd_sd <- NULL + + data <- c(data, create_obs_model( + obs, dates = R$date + )) + + if (data$obs_scale_sd > 0) { + stop( + "Cannot simulate from uncertain observation scaling; use fixed scaling ", + "instead." + ) + } + if (data$obs_scale) { + data$frac_obs <- array(data$obs_scale_mean, dim = c(1, 1)) + } else { + data$frac_obs <- array(dim = c(1, 0)) + } + data$obs_scale_mean <- NULL + data$obs_scale_sd <- NULL + + if (obs$family == "negbin") { + if (data$phi_sd > 0) { + stop( + "Cannot simulate from uncertain overdispersion; use fixed ", + "overdispersion instead." + ) + } + data$rep_phi <- array(data$phi_mean, dim = c(1, 1)) + } else { + data$rep_phi <- array(dim = c(1, 0)) + } + data$phi_mean <- NULL + data$phi_sd <- NULL + + ## day of week effect + if (is.null(day_of_week_effect)) { + day_of_week_effect <- rep(1, data$week_effect) + } + + day_of_week_effect <- day_of_week_effect / sum(day_of_week_effect) + data$day_of_week_simplex <- array( + day_of_week_effect, dim = c(1, data$week_effect) + ) + + # Create stan arguments + stan <- stan_opts(backend = backend, chains = 1, samples = 1, warmup = 1) + args <- create_stan_args( + stan, data = data, fixed_param = TRUE, model = "simulate_infections", + verbose = FALSE + ) + + ## simulate + sim <- fit_model(args, id = "simulate_infections") + + ## join batches + dates <- c( + seq(min(R$date) - seeding_time, min(R$date) - 1, by = "day"), + R$date + ) + out <- extract_parameter_samples(sim, data, + reported_inf_dates = dates, + reported_dates = dates[-(1:seeding_time)], + drop_length_1 = TRUE + ) + + out <- rbindlist(out[c("infections", "reported_cases")], idcol = "variable") + out <- out[, c("sample", "parameter", "time") := NULL] + + return(out[]) } #' Forecast infections from a given fit 
and trajectory of the time-varying diff --git a/inst/stan/data/simulation_rt.stan b/inst/stan/data/simulation_rt.stan index a97300451..3beae3161 100644 --- a/inst/stan/data/simulation_rt.stan +++ b/inst/stan/data/simulation_rt.stan @@ -1,5 +1,5 @@ - array[seeding_time ? n : 0, 1] real initial_infections; // initial logged infections - array[seeding_time > 1 ? n : 0, 1] real initial_growth; //initial growth + array[n, 1] real initial_infections; // initial logged infections + array[n, seeding_time > 1 ? 1 : 0] real initial_growth; //initial growth matrix[n, t - seeding_time] R; // reproduction number int pop; // susceptible population diff --git a/inst/stan/simulate_infections.stan b/inst/stan/simulate_infections.stan index d8448d047..815b8e8e9 100644 --- a/inst/stan/simulate_infections.stan +++ b/inst/stan/simulate_infections.stan @@ -63,7 +63,9 @@ generated quantities { to_vector(infections[i]), delay_rev_pmf, seeding_time) ); } else { - reports[i] = to_row_vector(infections[(seeding_time + 1):t]); + reports[i] = to_row_vector( + infections[i, (seeding_time + 1):t] + ); } // weekly reporting effect @@ -72,6 +74,18 @@ generated quantities { day_of_week_effect(to_vector(reports[i]), day_of_week, to_vector(day_of_week_simplex[i]))); } + // truncate near time cases to observed reports + if (trunc_id) { + vector[delay_type_max[trunc_id] + 1] trunc_rev_cmf = get_delay_rev_pmf( + trunc_id, delay_type_max[trunc_id] + 1, delay_types_p, delay_types_id, + delay_types_groups, delay_max, delay_np_pmf, + delay_np_pmf_groups, delay_mean[i], delay_sd[i], delay_dist, + 0, 1, 1 + ); + reports[i] = to_row_vector(truncate( + to_vector(reports[i]), trunc_rev_cmf, 0) + ); + } // scale observations if (obs_scale) { reports[i] = to_row_vector(scale_obs(to_vector(reports[i]), frac_obs[i, 1])); @@ -81,8 +95,8 @@ generated quantities { to_vector(reports[i]), rep_phi[i], model_type ); { - real gt_mean = rev_pmf_mean(gt_rev_pmf, 1); - real gt_var = rev_pmf_var(gt_rev_pmf, 1, gt_mean); + real 
gt_mean = rev_pmf_mean(gt_rev_pmf, 0); + real gt_var = rev_pmf_var(gt_rev_pmf, 0, gt_mean); r[i] = R_to_growth(to_vector(R[i]), gt_mean, gt_var); } } diff --git a/man/simulate_infections.Rd b/man/simulate_infections.Rd index f1f3f577d..0f6b6ab12 100644 --- a/man/simulate_infections.Rd +++ b/man/simulate_infections.Rd @@ -2,19 +2,102 @@ % Please edit documentation in R/simulate_infections.R \name{simulate_infections} \alias{simulate_infections} -\title{Deprecated; use \code{\link[=forecast_infections]{forecast_infections()}} instead} +\title{Simulate infections using the renewal equation} \usage{ -simulate_infections(...) +simulate_infections( + estimates, + R, + initial_infections, + day_of_week_effect = NULL, + generation_time = generation_time_opts(), + delays = delay_opts(), + truncation = trunc_opts(), + obs = obs_opts(), + CrIs = c(0.2, 0.5, 0.9), + backend = "rstan", + pop = 0, + ... +) } \arguments{ -\item{...}{Arguments to be passed to \code{\link[=forecast_infections]{forecast_infections()}}} +\item{estimates}{deprecated; use \code{\link[=forecast_infections]{forecast_infections()}} instead} + +\item{R}{a data frame of reproduction numbers (column \code{R}) by date (column +\code{date}). Column \code{R} must be numeric and \code{date} must be in date format. If +not all days between the first and last day in the \code{date} are present, +it will be assumed that R stays the same until the next given date.} + +\item{initial_infections}{numeric; the initial number of infections.} + +\item{day_of_week_effect}{either \code{NULL} (no day of the week effect) or a +numerical vector of length specified in \code{\link[=obs_opts]{obs_opts()}} as \code{week_length} +(default: 7) if \code{week_effect} is set to TRUE. Each element of the vector +gives the weight given to reporting on this day (normalised to 1). 
+The default is \code{NULL}.} + +\item{generation_time}{A call to \code{\link[=generation_time_opts]{generation_time_opts()}} defining the +generation time distribution used. For backwards compatibility a list of +summary parameters can also be passed.} + +\item{delays}{A call to \code{\link[=delay_opts]{delay_opts()}} defining delay distributions and +options. See the documentation of \code{\link[=delay_opts]{delay_opts()}} and the examples below for +details.} + +\item{truncation}{A call to \code{\link[=trunc_opts]{trunc_opts()}} defining the truncation of +observed data. Defaults to \code{\link[=trunc_opts]{trunc_opts()}}. See \code{\link[=estimate_truncation]{estimate_truncation()}} for +an approach to estimating truncation from data.} + +\item{obs}{A list of options as generated by \code{\link[=obs_opts]{obs_opts()}} defining the +observation model. Defaults to \code{\link[=obs_opts]{obs_opts()}}.} + +\item{CrIs}{Numeric vector of credible intervals to calculate.} + +\item{backend}{Character string indicating the backend to use for fitting +stan models. Supported arguments are "rstan" (default) or "cmdstanr".} + +\item{pop}{Integer, defaults to 0. Susceptible population initially present. +Used to adjust Rt estimates when otherwise fixed based on the proportion of +the population that is susceptible. When set to 0 no population adjustment +is done.} + +\item{...}{deprecated; only included for backward compatibility} } \value{ -the result of \code{\link[=forecast_infections]{forecast_infections()}} +A data.table of simulated infections (variable \code{infections}) and +reported cases (variable \code{reported_cases}) by date. } \description{ -\ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#deprecated}{\figure{lifecycle-deprecated.svg}{options: alt='[Deprecated]'}}}{\strong{[Deprecated]}} +Simulations are done from given initial infections and, potentially +time-varying, reproduction numbers. 
Delays and parameters of the observation +model can be specified using the same options as in \code{\link[=estimate_infections]{estimate_infections()}}. } \details{ -Calling this function passes all arguments to \code{\link[=forecast_infections]{forecast_infections()}} +In order to simulate, all parameters that are specified such as the mean and +standard deviation of delays or observation scaling, must be fixed. +Uncertain parameters are not allowed. + +A previous function called \code{\link[=simulate_infections]{simulate_infections()}} that simulates from a +given model fit has been renamed \code{\link[=forecast_infections]{forecast_infections()}}. Using +\code{\link[=simulate_infections]{simulate_infections()}} with existing estimates is now deprecated. This +option will be removed in version 2.1.0. +} +\examples{ +\donttest{ + R <- data.frame( + date = seq.Date(as.Date("2023-01-01"), length.out = 14, by = "day"), + R = c(rep(1.2, 7), rep(0.8, 7)) + ) + sim <- simulate_infections( + R = R, + initial_infections = 100, + generation_time = generation_time_opts( + fix_dist(example_generation_time) + ), + delays = delay_opts(fix_dist(example_reporting_delay)), + obs = obs_opts(family = "poisson") + ) +} +} +\author{ +Sebastian Funk } diff --git a/tests/testthat/_snaps/simulate-infections.md b/tests/testthat/_snaps/simulate-infections.md new file mode 100644 index 000000000..5f0ba10c8 --- /dev/null +++ b/tests/testthat/_snaps/simulate-infections.md @@ -0,0 +1,68 @@ +# simulate_infections works as expected with standard parameters + + variable date value + + 1: infections 2023-01-01 120.00000 + 2: infections 2023-01-02 144.00000 + 3: infections 2023-01-03 172.80000 + 4: infections 2023-01-04 207.36000 + 5: infections 2023-01-05 248.83200 + 6: infections 2023-01-06 298.59840 + 7: infections 2023-01-07 358.31808 + 8: infections 2023-01-08 286.65446 + 9: infections 2023-01-09 229.32357 + 10: infections 2023-01-10 183.45886 + 11: infections 2023-01-11 146.76709 + 12: 
infections 2023-01-12 117.41367 + 13: infections 2023-01-13 93.93093 + 14: infections 2023-01-14 75.14475 + 15: reported_cases 2023-01-01 128.00000 + 16: reported_cases 2023-01-02 151.00000 + 17: reported_cases 2023-01-03 145.00000 + 18: reported_cases 2023-01-04 188.00000 + 19: reported_cases 2023-01-05 252.00000 + 20: reported_cases 2023-01-06 276.00000 + 21: reported_cases 2023-01-07 371.00000 + 22: reported_cases 2023-01-08 273.00000 + 23: reported_cases 2023-01-09 234.00000 + 24: reported_cases 2023-01-10 192.00000 + 25: reported_cases 2023-01-11 157.00000 + 26: reported_cases 2023-01-12 120.00000 + 27: reported_cases 2023-01-13 78.00000 + 28: reported_cases 2023-01-14 63.00000 + variable date value + +# simulate_infections works as expected with additional parameters + + variable date value + + 1: infections 2023-01-01 240.6049 + 2: infections 2023-01-02 253.8951 + 3: infections 2023-01-03 267.7099 + 4: infections 2023-01-04 282.1928 + 5: infections 2023-01-05 297.4201 + 6: infections 2023-01-06 313.4477 + 7: infections 2023-01-07 330.3258 + 8: infections 2023-01-08 232.0685 + 9: infections 2023-01-09 222.6294 + 10: infections 2023-01-10 212.0540 + 11: infections 2023-01-11 201.3581 + 12: infections 2023-01-12 190.8991 + 13: infections 2023-01-13 180.8132 + 14: infections 2023-01-14 171.1484 + 15: reported_cases 2023-01-01 425.0000 + 16: reported_cases 2023-01-02 335.0000 + 17: reported_cases 2023-01-03 376.0000 + 18: reported_cases 2023-01-04 250.0000 + 19: reported_cases 2023-01-05 301.0000 + 20: reported_cases 2023-01-06 275.0000 + 21: reported_cases 2023-01-07 844.0000 + 22: reported_cases 2023-01-08 235.0000 + 23: reported_cases 2023-01-09 205.0000 + 24: reported_cases 2023-01-10 251.0000 + 25: reported_cases 2023-01-11 239.0000 + 26: reported_cases 2023-01-12 80.0000 + 27: reported_cases 2023-01-13 128.0000 + 28: reported_cases 2023-01-14 276.0000 + variable date value + diff --git a/tests/testthat/test-simulate_infections.R 
b/tests/testthat/test-forecast-infections.R similarity index 94% rename from tests/testthat/test-simulate_infections.R rename to tests/testthat/test-forecast-infections.R index 3221e2721..3bd49b509 100644 --- a/tests/testthat/test-simulate_infections.R +++ b/tests/testthat/test-forecast-infections.R @@ -51,7 +51,7 @@ test_that("forecast_infections works to simulate a passed in estimate_infections expect_equal(tail(sims$summarised[variable == "R"]$median, 30), R[41:70]) }) -test_that("simulate infections can be run with a limited number of samples", { +test_that("forecast infections can be run with a limited number of samples", { R <- c(rep(NA_real_, 40), rep(1.2, 15), rep(0.8, 15)) sims <- forecast_infections(out, R, samples = 10) expect_equal(names(sims), c("samples", "summarised", "observations")) @@ -59,7 +59,7 @@ test_that("simulate infections can be run with a limited number of samples", { expect_equal(max(sims$samples$sample), 10) }) -test_that("simulate infections fails as expected", { +test_that("forecast infections fails as expected", { expect_error(forecast_infections()) expect_error(forecast_infections(out[-"fit"])) }) @@ -79,7 +79,7 @@ test_that("forecast_infections works to simulate a passed in estimate_infections expect_equal(names(sims_sample), c("samples", "summarised", "observations")) }) -test_that("simulate_infections is deprecated", { +test_that("simulate_infections with a given estimate is deprecated", { expect_deprecated( sims <- simulate_infections(out) ) diff --git a/tests/testthat/test-simulate-infections.R b/tests/testthat/test-simulate-infections.R new file mode 100644 index 000000000..758c93fde --- /dev/null +++ b/tests/testthat/test-simulate-infections.R @@ -0,0 +1,51 @@ +skip_on_cran() + +R <- data.frame( + date = seq.Date(as.Date("2023-01-01"), length.out = 14, by = "day"), + R = c(rep(1.2, 7), rep(0.8, 7)) +) +initial_infections <- 100 +test_simulate_infections <- function(obs = obs_opts(family = "poisson"), ...) 
{ + sim <- simulate_infections( + R = R, + initial_infections = 100, + obs = obs, + ... + ) + return(sim) +} + +test_that("simulate_infections works as expected with standard parameters", { + sim <- test_simulate_infections() + expect_equal(nrow(sim), 2 * nrow(R)) + expect_snapshot_output(sim) +}) + +test_that("simulate_infections works as expected with additional parameters", { + sim <- test_simulate_infections( + generation_time = generation_time_opts(fix_dist(example_generation_time)), + delays = delay_opts(fix_dist(example_reporting_delay)), + obs = obs_opts(family = "negbin", phi = c(0.5, 0)) + ) + expect_equal(nrow(sim), 2 * nrow(R)) + expect_snapshot_output(sim) +}) + +test_that("simulate_infections fails with uncertain parameters", { + expect_error( + test_simulate_infections(obs = obs_opts(family = "negbin")), + "uncertain" + ) + expect_error( + test_simulate_infections( + obs = obs_opts(scale = list(mean = 1, sd = 1)) + ), + "uncertain" + ) + expect_error( + test_simulate_infections( + delays = delay_opts(example_incubation_period) + ), + "uncertain" + ) +})