diff --git a/DESCRIPTION b/DESCRIPTION index 93551fa60..a675ac158 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -99,6 +99,7 @@ BugReports: https://github.com/epiforecasts/EpiNow2/issues Depends: R (>= 3.5.0) Imports: + checkmate, data.table, futile.logger (>= 1.4), future, diff --git a/NAMESPACE b/NAMESPACE index fba91c11d..c96dfd4a5 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -100,7 +100,20 @@ export(update_secondary_args) import(Rcpp) import(methods) import(rstantools) +importFrom(R.utils,isDirectory) importFrom(R.utils,withTimeout) +importFrom(checkmate,assert_character) +importFrom(checkmate,assert_class) +importFrom(checkmate,assert_data_frame) +importFrom(checkmate,assert_date) +importFrom(checkmate,assert_integerish) +importFrom(checkmate,assert_logical) +importFrom(checkmate,assert_names) +importFrom(checkmate,assert_numeric) +importFrom(checkmate,assert_path_for_output) +importFrom(checkmate,assert_string) +importFrom(checkmate,test_data_frame) +importFrom(checkmate,test_numeric) importFrom(data.table,":=") importFrom(data.table,.N) importFrom(data.table,as.data.table) @@ -186,6 +199,7 @@ importFrom(purrr,safely) importFrom(purrr,transpose) importFrom(purrr,walk) importFrom(rlang,abort) +importFrom(rlang,arg_match) importFrom(rlang,cnd_muffle) importFrom(rlang,warn) importFrom(rstan,expose_stan_functions) diff --git a/NEWS.md b/NEWS.md index ffa6c3a35..5dc3706a0 100644 --- a/NEWS.md +++ b/NEWS.md @@ -18,6 +18,7 @@ * Reduced the number of long-running examples. By @sbfnk in #459 and reviewed by @seabbs. * Changed all instances of arguments that refer to the maximum of a distribution to reflect the maximum. Previously this did, in some instance, refer to the length of the PMF. By @sbfnk in #468. * Fixed a bug in the bounds of delays when setting initial conditions. By @sbfnk in #474. +* Added input checking to `estimate_infections()`, `estimate_secondary()`, `estimate_truncation()`, `simulate_infections()`, and `epinow()`. 
`check_reports_valid()` has been added to validate the reports dataset passed to these functions. Tests are added to check `check_reports_valid()`. As part of input validation, the various `*_opts()` functions now return subclasses of the same name as the functions and are tested against passed arguments to ensure the right `*_opts()` is passed to the right argument. For example, the `obs` argument in `estimate_secondary()` is expected to only receive arguments passed through `obs_opts()` and will error otherwise. By @jamesmbaazam in #476 and reviewed by @sbfnk and @seabbs. ## Model changes diff --git a/R/checks.R b/R/checks.R new file mode 100644 index 000000000..ab08fe691 --- /dev/null +++ b/R/checks.R @@ -0,0 +1,56 @@ +#' Validate data input +#' +#' @description +#' `check_reports_valid()` checks that the supplied data is a `<data.frame>`, +#' and that it has the right column names and types. In particular, it checks +#' that the date column is in date format and does not contain NA's, and that +#' the other columns are numeric. +#' +#' @param reports A data frame with either: +#' * a minimum of two columns: `date` and `confirm`, if to be +#' used by [estimate_infections()] or [estimate_truncation()], or +#' * a minimum of three columns: `date`, `primary`, and `secondary`, if to be +#' used by [estimate_secondary()]. +#' @param model The EpiNow2 model to be used. Either +#' "estimate_infections", "estimate_truncation", or "estimate_secondary". +#' This is used to determine which checks to perform on the data input. +#' @importFrom checkmate assert_data_frame assert_date assert_names +#' assert_numeric +#' @importFrom rlang arg_match +#' @return Called for its side effects. +#' @author James M. 
Azam +#' @keywords internal +check_reports_valid <- function(reports, model) { + # Check that the case time series (reports) is a data frame + assert_data_frame(reports) + # Perform checks depending on the model the data is meant to be used with + model <- arg_match( + model, + values = c( + "estimate_infections", + "estimate_truncation", + "estimate_secondary" + ) + ) + + if (model == "estimate_secondary") { + # Check that reports has the right column names + assert_names( + names(reports), + must.include = c("date", "primary", "secondary") + ) + # Check that the reports data.frame has the right column types + assert_date(reports$date, any.missing = FALSE) + assert_numeric(reports$primary, lower = 0) + assert_numeric(reports$secondary, lower = 0) + } else { + # Check that reports has the right column names + assert_names( + names(reports), + must.include = c("date", "confirm") + ) + # Check that the reports data.frame has the right column types + assert_date(reports$date, any.missing = FALSE) + assert_numeric(reports$confirm, lower = 0) + } +} diff --git a/R/create.R b/R/create.R index 4ce04559f..82652a652 100644 --- a/R/create.R +++ b/R/create.R @@ -170,9 +170,9 @@ create_future_rt <- function(future = "latest", delay = 0) { "estimate" ) ) - if (!(future %in% "project")) { + if (!(future == "project")) { out$fixed <- TRUE - out$from <- ifelse(future %in% "latest", 0, -delay) + out$from <- ifelse(future == "latest", 0, -delay) } } else if (is.numeric(future)) { out$fixed <- TRUE @@ -227,7 +227,7 @@ create_rt_data <- function(rt = rt_opts(), breakpoints = NULL, # apply random walk if (rt$rw != 0) { breakpoints <- as.integer(seq_along(breakpoints) %% rt$rw == 0) - if (!(rt$future %in% "project")) { + if (!(rt$future == "project")) { max_bps <- length(breakpoints) - horizon + future_rt$from if (max_bps < length(breakpoints)) { breakpoints[(max_bps + 1):length(breakpoints)] <- 0 @@ -248,7 +248,7 @@ create_rt_data <- function(rt = rt_opts(), breakpoints = NULL, 
future_fixed = as.numeric(future_rt$fixed), fixed_from = future_rt$from, pop = rt$pop, - stationary = as.numeric(rt$gp_on %in% "R0"), + stationary = as.numeric(rt$gp_on == "R0"), future_time = horizon - future_rt$from ) return(rt_data) @@ -383,7 +383,7 @@ create_gp_data <- function(gp = gp_opts(), data) { #' create_obs_model(obs_opts(week_length = 3), dates = dates) create_obs_model <- function(obs = obs_opts(), dates) { data <- list( - model_type = as.numeric(obs$family %in% "negbin"), + model_type = as.numeric(obs$family == "negbin"), phi_mean = obs$phi[1], phi_sd = obs$phi[2], week_effect = ifelse(obs$week_effect, obs$week_length, 1), diff --git a/R/epinow.R b/R/epinow.R index a7bb4064d..7da0b4398 100644 --- a/R/epinow.R +++ b/R/epinow.R @@ -33,6 +33,9 @@ #' @importFrom lubridate days #' @importFrom futile.logger flog.fatal flog.warn flog.error flog.debug ftry #' @importFrom rlang cnd_muffle +#' @importFrom checkmate assert_string assert_path_for_output +#' assert_date assert_logical +#' @importFrom R.utils isDirectory #' @author Sam Abbott #' @examples #' \donttest{ @@ -105,6 +108,17 @@ epinow <- function(reported_cases, plot_args = list(), target_folder = NULL, target_date, logs = tempdir(), id = "epinow", verbose = interactive()) { + # Check inputs (a NULL target_folder is valid; otherwise it must exist) + assert_logical(return_output) + stopifnot("target_folder is not a directory" = + is.null(target_folder) || isDirectory(target_folder) + ) + if (!missing(target_date)) { + assert_string(target_date) + } + assert_string(id) + assert_logical(verbose) + if (is.null(target_folder)) { return_output <- TRUE } @@ -251,7 +265,7 @@ epinow <- function(reported_cases, } ), error = function(e) { - if (id %in% "epinow") { + if (id == "epinow") { stop(e) } else { error_text <- sprintf("%s: %s - %s", id, e$message, toString(e$call)) @@ -269,15 +283,15 @@ epinow <- function(reported_cases, } if (!is.null(target_folder) && !is.null(out$error)) { - saveRDS(out$error, paste0(target_folder, "/error.rds")) - saveRDS(out$trace, 
paste0(target_folder, "/trace.rds")) + saveRDS(out$error, file.path(target_folder, "error.rds")) + saveRDS(out$trace, file.path(target_folder, "trace.rds")) } # log timing if specified if (output["timing"]) { out$timing <- round(as.numeric(end_time - start_time), 1) if (!is.null(target_folder)) { - saveRDS(out$timing, paste0(target_folder, "/runtime.rds")) + saveRDS(out$timing, file.path(target_folder, "runtime.rds")) } } diff --git a/R/estimate_infections.R b/R/estimate_infections.R index 73a29cf99..e39429cd2 100644 --- a/R/estimate_infections.R +++ b/R/estimate_infections.R @@ -65,6 +65,8 @@ #' @importFrom lubridate days #' @importFrom purrr transpose #' @importFrom futile.logger flog.threshold flog.warn flog.debug +#' @importFrom checkmate assert_class assert_numeric assert_logical +#' assert_string #' @examples #' \donttest{ #' # set number of cores to use @@ -131,6 +133,24 @@ estimate_infections <- function(reported_cases, weigh_delay_priors = TRUE, id = "estimate_infections", verbose = interactive()) { + # Validate inputs + check_reports_valid(reported_cases, model = "estimate_infections") + assert_class(generation_time, "generation_time_opts") + assert_class(delays, "delay_opts") + assert_class(truncation, "trunc_opts") + assert_class(rt, "rt_opts", null.ok = TRUE) + assert_class(backcalc, "backcalc_opts") + assert_class(gp, "gp_opts", null.ok = TRUE) + assert_class(obs, "obs_opts") + assert_class(stan, "stan_opts") + assert_numeric(horizon, lower = 0) + assert_numeric(CrIs, lower = 0, upper = 1) + assert_logical(filter_leading_zeros) + assert_numeric(zero_threshold, lower = 0) + assert_logical(weigh_delay_priors) + assert_string(id) + assert_logical(verbose) + set_dt_single_thread() # store dirty reported case data @@ -211,7 +231,7 @@ estimate_infections <- function(reported_cases, # Initialise fitting by using a previous fit or fitting to cumulative cases if (!is.null(args$init_fit)) { if (!inherits(args$init_fit, "stanfit") && - args$init_fit %in% 
"cumulative") { + args$init_fit == "cumulative") { args$init_fit <- init_cumulative_fit(args, warmup = 50, samples = 50, id = id, verbose = FALSE @@ -435,22 +455,12 @@ fit_model_with_nuts <- function(args, future = FALSE, max_execution_time = Inf, } } - if (!future) { - fit <- fit_chain(1, - stan_args = args, max_time = max_execution_time, - catch = !id %in% c("estimate_infections", "epinow") - ) - if (stuck_chains > 0) { - fit <- NULL - } - if (is.null(fit)) { - rlang::abort("model fitting was timed out or failed") - } - } else { + if (future) { chains <- args$chains args$chains <- 1 args$cores <- 1 - fits <- future.apply::future_lapply(1:chains, fit_chain, + fits <- future.apply::future_lapply(1:chains, + fit_chain, stan_args = args, max_time = max_execution_time, catch = TRUE, @@ -478,12 +488,23 @@ fit_model_with_nuts <- function(args, future = FALSE, max_execution_time = Inf, if ((chains - failed_chains) < 2) { rlang::abort( "model fitting failed as too few chains were returned to assess", - " convergence (2 or more required)" + " convergence (2 or more required)" ) } } fit <- rstan::sflist2stanfit(fit) } + } else { + fit <- fit_chain(1, + stan_args = args, max_time = max_execution_time, + catch = !id %in% c("estimate_infections", "epinow") + ) + if (stuck_chains > 0) { + fit <- NULL + } + if (is.null(fit)) { + rlang::abort("model fitting was timed out or failed") + } } return(fit) } diff --git a/R/estimate_secondary.R b/R/estimate_secondary.R index 5718aa08f..c720fa0fa 100644 --- a/R/estimate_secondary.R +++ b/R/estimate_secondary.R @@ -66,6 +66,8 @@ #' @importFrom lubridate wday #' @importFrom data.table as.data.table merge.data.table #' @importFrom utils modifyList +#' @importFrom checkmate assert_class assert_numeric assert_data_frame +#' assert_logical #' @examples #' \donttest{ #' # set number of cores to use @@ -149,6 +151,19 @@ estimate_secondary <- function(reports, weigh_delay_priors = FALSE, verbose = interactive(), ...) 
{ + # Validate the inputs + check_reports_valid(reports, model = "estimate_secondary") + assert_class(secondary, "secondary_opts") + assert_class(delays, "delay_opts") + assert_class(truncation, "trunc_opts") + assert_class(obs, "obs_opts") + assert_numeric(burn_in, lower = 0) + assert_numeric(CrIs, lower = 0, upper = 1) + assert_data_frame(priors, null.ok = TRUE) + assert_class(model, "stanfit", null.ok = TRUE) + assert_logical(weigh_delay_priors) + assert_logical(verbose) + reports <- data.table::as.data.table(reports) if (burn_in >= nrow(reports)) { @@ -238,14 +253,14 @@ estimate_secondary <- function(reports, #' options that can be passed. #' #' @seealso estimate_secondary -#' @return A list of binary options summarising secondary model used in -#' `estimate_secondary()`. Options returned are `cumulative` (should the -#' secondary report be cumulative), `historic` (should a convolution of primary -#' reported cases be used to predict secondary reported cases), -#' `primary_hist_additive` (should the historic convolution of primary reported -#' cases be additive or subtractive), `current` (should currently observed -#' primary reported cases contribute to current secondary reported cases), -#' `primary_current_additive` (should current primary reported cases be +#' @return A `<secondary_opts>` object of binary options summarising secondary +#' model used in `estimate_secondary()`. Options returned are `cumulative` +#' (should the secondary report be cumulative), `historic` (should a +#' convolution of primary reported cases be used to predict secondary reported +#' cases), `primary_hist_additive` (should the historic convolution of primary +#' reported cases be additive or subtractive), `current` (should currently +#' observed primary reported cases contribute to current secondary reported +#' cases), `primary_current_additive` (should current primary reported cases be #' additive or subtractive). 
#' #' @export @@ -258,7 +273,7 @@ estimate_secondary <- function(reports, #' secondary_opts("prevalence") secondary_opts <- function(type = "incidence", ...) { type <- match.arg(type, choices = c("incidence", "prevalence")) - if (type %in% "incidence") { + if (type == "incidence") { data <- list( cumulative = 0, historic = 1, @@ -266,7 +281,7 @@ secondary_opts <- function(type = "incidence", ...) { current = 0, primary_current_additive = 0 ) - } else if (type %in% "prevalence") { + } else if (type == "prevalence") { data <- list( cumulative = 1, historic = 1, @@ -276,6 +291,7 @@ secondary_opts <- function(type = "incidence", ...) { ) } data <- modifyList(data, list(...)) + attr(data, "class") <- c("secondary_opts", class(data)) return(data) } diff --git a/R/estimate_truncation.R b/R/estimate_truncation.R index f63fa82e8..a6a8215c8 100644 --- a/R/estimate_truncation.R +++ b/R/estimate_truncation.R @@ -71,10 +71,12 @@ #' @export #' @inheritParams calc_CrIs #' @inheritParams estimate_infections -#' @importFrom purrr map reduce map_dbl +#' @importFrom purrr map reduce map_dbl walk #' @importFrom rstan sampling #' @importFrom data.table copy .N as.data.table merge.data.table setDT #' @importFrom data.table setcolorder +#' @importFrom checkmate assert_character assert_numeric assert_class +#' assert_logical #' @examples #' # set number of cores to use #' old_opts <- options() @@ -148,6 +150,13 @@ estimate_truncation <- function(obs, max_truncation, trunc_max = 10, weigh_delay_priors = FALSE, verbose = TRUE, ...) 
{ + # Validate inputs + walk(obs, check_reports_valid, model = "estimate_truncation") + assert_class(truncation, "dist_spec") + assert_class(model, "stanfit", null.ok = TRUE) + assert_numeric(CrIs, lower = 0, upper = 1) + assert_logical(weigh_delay_priors) + assert_logical(verbose) ## code block to remove in EpiNow2 2.0.0 construct_trunc <- FALSE diff --git a/R/opts.R b/R/opts.R index 8e2101ad5..5a1ffa6c5 100644 --- a/R/opts.R +++ b/R/opts.R @@ -15,7 +15,8 @@ #' @param prior_weight deprecated; prior weights are now specified as a #' model option. Use the `weigh_delay_priors` argument of `estimate_infections` #' instead. -#' @return A list summarising the input delay distributions. +#' @return A `<generation_time_opts>` object summarising the input delay +#' distributions. #' @author Sebastian Funk #' @author Sam Abbott #' @seealso convert_to_logmean convert_to_logsd bootstrapped_dist_fit dist_spec @@ -93,6 +94,7 @@ generation_time_opts <- function(dist = dist_spec(mean = 1), ..., "information, see the relevant documentation pages using ", "`?generation_time_opts`") } + attr(dist, "class") <- c("generation_time_opts", class(dist)) return(dist) } @@ -105,7 +107,7 @@ generation_time_opts <- function(dist = dist_spec(mean = 1), ..., #' using [dist_spec()]. Default is an empty call to [dist_spec()], i.e. no delay #' @param ... deprecated; use `dist` instead #' @param fixed deprecated; use `dist` instead -#' @return A list summarising the input delay distributions. +#' @return A `<delay_opts>` object summarising the input delay distributions. 
#' @author Sam Abbott #' @author Sebastian Funk #' @seealso convert_to_logmean convert_to_logsd bootstrapped_dist_fit dist_spec @@ -155,6 +157,7 @@ delay_opts <- function(dist = dist_spec(), ..., fixed = FALSE) { ## can be removed once dot options are hard deprecated stop("Unknown named arguments passed to `delay_opts`") } + attr(dist, "class") <- c("delay_opts", class(dist)) return(dist) } @@ -168,7 +171,8 @@ delay_opts <- function(dist = dist_spec(), ..., fixed = FALSE) { #' @param dist A delay distribution or series of delay distributions reflecting #' the truncation generated using [dist_spec()] or [estimate_truncation()]. #' Default is an empty call to [dist_spec()], i.e. no truncation -#' @return A list summarising the input truncation distribution. +#' @return A `<trunc_opts>` object summarising the input truncation +#' distribution. #' #' @author Sam Abbott #' @author Sebastian Funk @@ -196,6 +200,7 @@ trunc_opts <- function(dist = dist_spec()) { "`?trunc_opts`" ) } + attr(dist, "class") <- c("trunc_opts", class(dist)) return(dist) } @@ -237,7 +242,8 @@ trunc_opts <- function(dist = dist_spec()) { #' but the method relying on a global mean will revert to this for real time #' estimates, which may not be desirable. #' -#' @return A list of settings defining the time-varying reproduction number. +#' @return An `<rt_opts>` object with settings defining the time-varying +#' reproduction number. #' @author Sam Abbott #' @inheritParams create_future_rt #' @export @@ -275,6 +281,7 @@ rt_opts <- function(prior = list(mean = 1, sd = 1), if (!("mean" %in% names(rt$prior) && "sd" %in% names(rt$prior))) { stop("prior must have both a mean and sd specified") } + attr(rt, "class") <- c("rt_opts", class(rt)) return(rt) } @@ -306,7 +313,7 @@ rt_opts <- function(prior = list(mean = 1, sd = 1), #' average to use when estimating Rt. This must be odd so that the central #' estimate is included. #' -#' @return A list of back calculation settings. 
+#' @return A `<backcalc_opts>` object of back calculation settings. #' @author Sam Abbott #' @export #' @examples @@ -324,6 +331,7 @@ backcalc_opts <- function(prior = "reports", prior_window = 14, rt_window = 1) { estimate" ) } + attr(backcalc, "class") <- c("backcalc_opts", class(backcalc)) return(backcalc) } @@ -370,7 +378,7 @@ backcalc_opts <- function(prior = "reports", prior_window = 14, rt_window = 1) { #' approximate Gaussian process. See (Riutort-Mayol et al. 2020 #' <https://arxiv.org/abs/2004.11408>) for advice on updating this default. #' -#' @return A list of settings defining the Gaussian process +#' @return A `<gp_opts>` object of settings defining the Gaussian process #' @author Sam Abbott #' @export #' @examples @@ -403,6 +411,7 @@ gp_opts <- function(basis_prop = 0.2, if (gp$matern_type != 3 / 2) { stop("only the Matern 3/2 kernel is currently supported") # nolint } + attr(gp, "class") <- c("gp_opts", class(gp)) return(gp) } @@ -437,7 +446,7 @@ gp_opts <- function(basis_prop = 0.2, #' @param return_likelihood Logical, defaults to `FALSE`. Should the likelihood #' be returned by the model. #' -#' @return A list of observation model settings. +#' @return An `<obs_opts>` object of observation model settings. #' @author Sam Abbott #' @export #' @examples @@ -478,6 +487,7 @@ obs_opts <- function(family = "negbin", stop("If specifying a scale both a mean and sd are needed") } } + attr(obs, "class") <- c("obs_opts", class(obs)) return(obs) } @@ -538,6 +548,7 @@ rstan_sampling_opts <- function(cores = getOption("mc.cores", 1L), future = FALSE, max_execution_time = Inf, ...) { + dot_args <- list(...) opts <- list( cores = cores, warmup = warmup, @@ -549,8 +560,9 @@ rstan_sampling_opts <- function(cores = getOption("mc.cores", 1L), ) control_def <- list(adapt_delta = 0.95, max_treedepth = 15) opts$control <- modifyList(control_def, control) + dot_args$iter <- NULL opts$iter <- ceiling(samples / opts$chains) + opts$warmup - opts <- c(opts, ...) 
+ opts <- c(opts, dot_args) return(opts) } @@ -625,9 +637,9 @@ rstan_opts <- function(object = NULL, object = object, method = method ) - if (method %in% "sampling") { + if (method == "sampling") { opts <- c(opts, rstan_sampling_opts(samples = samples, ...)) - } else if (method %in% "vb") { + } else if (method == "vb") { opts <- c(opts, rstan_vb_opts(samples = samples, ...)) } return(opts) @@ -661,7 +673,8 @@ rstan_opts <- function(object = NULL, #' #' @param ... Additional parameters to pass underlying option functions. #' -#' @return A list of arguments to pass to the appropriate rstan functions. +#' @return A `<stan_opts>` object of arguments to pass to the appropriate +#' rstan functions. #' @author Sam Abbott #' @export #' @inheritParams rstan_opts @@ -678,7 +691,7 @@ stan_opts <- function(samples = 2000, return_fit = TRUE, ...) { backend <- match.arg(backend, choices = "rstan") - if (backend %in% "rstan") { + if (backend == "rstan") { opts <- rstan_opts( samples = samples, ... @@ -691,6 +704,7 @@ stan_opts <- function(samples = 2000, opts$init_fit <- init_fit } opts <- c(opts, list(return_fit = return_fit)) + attr(opts, "class") <- c("stan_opts", class(opts)) return(opts) } diff --git a/R/simulate_infections.R b/R/simulate_infections.R index ada43c60c..8adf2db27 100644 --- a/R/simulate_infections.R +++ b/R/simulate_infections.R @@ -34,6 +34,8 @@ #' @importFrom progressr with_progress progressor #' @importFrom data.table rbindlist as.data.table #' @importFrom lubridate days +#' @importFrom checkmate assert_class assert_names test_numeric test_data_frame +#' assert_numeric assert_integerish assert_logical #' @return A list of output as returned by [estimate_infections()] but based on #' results from the specified scenario rather than fitting. 
#' @export @@ -89,10 +91,22 @@ simulate_infections <- function(estimates, samples = NULL, batch_size = 10, verbose = interactive()) { - ## check batch size - if (!is.null(batch_size) && batch_size <= 1) { - stop("batch_size must be greater than 1") + ## check inputs (batch_size may be NULL, as the previous check allowed) + assert_class(estimates, "estimate_infections") + assert_names(names(estimates), must.include = "fit") + stopifnot( + "R must either be a numeric vector or a data.frame" = + test_numeric(R, lower = 0, null.ok = TRUE) || + test_data_frame(R, null.ok = TRUE) + ) + if (test_data_frame(R)) { + assert_names(names(R), must.include = c("date", "value")) + assert_numeric(R$value, lower = 0) } + assert_class(model, "stanfit", null.ok = TRUE) + assert_integerish(samples, lower = 1, null.ok = TRUE) + assert_integerish(batch_size, lower = 2, null.ok = TRUE) + assert_logical(verbose) ## extract samples from given stanfit object draws <- extract(estimates$fit, pars = c( diff --git a/README.md b/README.md index 430ce61f7..8c6ccd09d 100644 --- a/README.md +++ b/README.md @@ -268,13 +268,13 @@ parameters at the latest date partially supported by data. knitr::kable(summary(estimates)) ``` -| measure | estimate | -|:--------------------------------------|:------------------------| -| New confirmed cases by infection date | 2289 (1115 – 4367) | -| Expected change in daily cases | Likely decreasing | -| Effective reproduction no. | 0.89 (0.62 – 1.2) | -| Rate of growth | -0.025 (-0.096 – 0.036) | -| Doubling/halving time (days) | -27 (19 – -7.2) | +| measure | estimate | +|:--------------------------------------|:----------------------| +| New confirmed cases by infection date | 2243 (1144 – 4190) | +| Expected change in daily cases | Likely decreasing | +| Effective reproduction no. | 0.88 (0.6 – 1.2) | +| Rate of growth | -0.028 (-0.1 – 0.035) | +| Doubling/halving time (days) | -25 (20 – -6.9) | Summarised parameter estimates can also easily be returned, either filtered for a single parameter or for all parameters. 
@@ -282,19 +282,19 @@ filtered for a single parameter or for all parameters. ``` r head(summary(estimates, type = "parameters", params = "R")) #> date variable strat type median mean sd lower_90 -#> 1: 2020-02-22 R estimate 2.212959 2.220541 0.14576748 1.988091 -#> 2: 2020-02-23 R estimate 2.182545 2.186632 0.12162401 1.992521 -#> 3: 2020-02-24 R estimate 2.150369 2.150836 0.10169055 1.986534 -#> 4: 2020-02-25 R estimate 2.113105 2.113283 0.08594870 1.974178 -#> 5: 2020-02-26 R estimate 2.073614 2.074148 0.07420031 1.955402 -#> 6: 2020-02-27 R estimate 2.034058 2.033642 0.06600593 1.928757 +#> 1: 2020-02-22 R estimate 2.227000 2.234991 0.14475898 2.012280 +#> 2: 2020-02-23 R estimate 2.194586 2.199036 0.11981031 2.012784 +#> 3: 2020-02-24 R estimate 2.158077 2.160968 0.09983093 2.002390 +#> 4: 2020-02-25 R estimate 2.120294 2.120978 0.08450716 1.986945 +#> 5: 2020-02-26 R estimate 2.080098 2.079349 0.07331063 1.962580 +#> 6: 2020-02-27 R estimate 2.036740 2.036424 0.06555145 1.928347 #> lower_50 lower_20 upper_20 upper_50 upper_90 -#> 1: 2.121108 2.179966 2.251215 2.319625 2.463337 -#> 2: 2.103825 2.154804 2.213847 2.267385 2.395556 -#> 3: 2.081757 2.123637 2.174697 2.216957 2.325451 -#> 4: 2.055113 2.091263 2.133835 2.167701 2.262721 -#> 5: 2.024078 2.053914 2.092803 2.120727 2.202303 -#> 6: 1.988182 2.016474 2.051127 2.074294 2.145812 +#> 1: 2.132905 2.188984 2.268706 2.332697 2.484471 +#> 2: 2.113308 2.162047 2.228660 2.280186 2.402885 +#> 3: 2.091295 2.132324 2.184877 2.228413 2.327439 +#> 4: 2.063099 2.098374 2.140263 2.176206 2.266145 +#> 5: 2.028392 2.060083 2.097677 2.124936 2.203185 +#> 6: 1.992552 2.020233 2.053117 2.077413 2.146660 ``` Reported cases are returned in a separate data frame in order to @@ -303,19 +303,19 @@ streamline the reporting of forecasts and for model evaluation. 
``` r head(summary(estimates, output = "estimated_reported_cases")) #> date type median mean sd lower_90 lower_50 lower_20 -#> 1: 2020-02-22 gp_rt 66 68.2645 19.68574 40.00 54 61 -#> 2: 2020-02-23 gp_rt 77 78.8550 21.40584 48.00 64 71 -#> 3: 2020-02-24 gp_rt 76 77.8660 21.55264 47.00 63 71 -#> 4: 2020-02-25 gp_rt 73 75.0555 20.84360 46.00 60 67 -#> 5: 2020-02-26 gp_rt 78 80.2720 22.22163 48.95 65 73 -#> 6: 2020-02-27 gp_rt 111 112.9485 29.56885 70.95 93 103 +#> 1: 2020-02-22 gp_rt 66 67.7370 17.87953 42.00 55.00 62 +#> 2: 2020-02-23 gp_rt 76 77.7610 20.41899 47.00 63.00 71 +#> 3: 2020-02-24 gp_rt 77 77.6055 20.03036 47.00 64.00 72 +#> 4: 2020-02-25 gp_rt 74 75.5260 19.66621 47.00 61.00 70 +#> 5: 2020-02-26 gp_rt 78 79.4530 20.99436 48.95 64.00 73 +#> 6: 2020-02-27 gp_rt 111 113.4765 28.82705 71.95 93.75 103 #> upper_20 upper_50 upper_90 -#> 1: 71 80 103.00 -#> 2: 82 91 116.05 -#> 3: 81 91 117.00 -#> 4: 78 88 112.00 -#> 5: 83 93 120.00 -#> 6: 118 131 165.00 +#> 1: 70.0 79 102.00 +#> 2: 81.0 91 112.00 +#> 3: 82.0 90 113.00 +#> 4: 79.0 87 110.00 +#> 5: 83.0 93 117.00 +#> 6: 117.4 131 164.05 ``` A range of plots are returned (with the single summary plot shown @@ -364,19 +364,19 @@ estimates <- regional_epinow( gp = NULL, stan = stan_opts(cores = 4, warmup = 250, samples = 1000) ) -#> INFO [2023-10-23 16:31:10] Producing following optional outputs: regions, summary, samples, plots, latest -#> INFO [2023-10-23 16:31:10] Reporting estimates using data up to: 2020-04-21 -#> INFO [2023-10-23 16:31:11] No target directory specified so returning output -#> INFO [2023-10-23 16:31:11] Producing estimates for: testland, realland -#> INFO [2023-10-23 16:31:11] Regions excluded: none -#> INFO [2023-10-23 16:31:59] Completed estimates for: testland -#> INFO [2023-10-23 16:32:51] Completed estimates for: realland -#> INFO [2023-10-23 16:32:51] Completed regional estimates -#> INFO [2023-10-23 16:32:51] Regions with estimates: 2 -#> INFO [2023-10-23 16:32:51] Regions with runtime 
errors: 0 -#> INFO [2023-10-23 16:32:51] Producing summary -#> INFO [2023-10-23 16:32:51] No summary directory specified so returning summary output -#> INFO [2023-10-23 16:32:51] No target directory specified so returning timings +#> INFO [2023-10-26 15:41:15] Producing following optional outputs: regions, summary, samples, plots, latest +#> INFO [2023-10-26 15:41:15] Reporting estimates using data up to: 2020-04-21 +#> INFO [2023-10-26 15:41:15] No target directory specified so returning output +#> INFO [2023-10-26 15:41:15] Producing estimates for: testland, realland +#> INFO [2023-10-26 15:41:15] Regions excluded: none +#> INFO [2023-10-26 15:42:07] Completed estimates for: testland +#> INFO [2023-10-26 15:42:55] Completed estimates for: realland +#> INFO [2023-10-26 15:42:55] Completed regional estimates +#> INFO [2023-10-26 15:42:55] Regions with estimates: 2 +#> INFO [2023-10-26 15:42:55] Regions with runtime errors: 0 +#> INFO [2023-10-26 15:42:55] Producing summary +#> INFO [2023-10-26 15:42:55] No summary directory specified so returning summary output +#> INFO [2023-10-26 15:42:55] No target directory specified so returning timings ``` Results from each region are stored in a `regional` list with across @@ -401,8 +401,8 @@ knitr::kable(estimates$summary$summarised_results$table) | Region | New confirmed cases by infection date | Expected change in daily cases | Effective reproduction no. 
| Rate of growth | Doubling/halving time (days) | |:---------|:--------------------------------------|:-------------------------------|:---------------------------|:------------------------|:-----------------------------| -| realland | 2149 (1092 – 4175) | Likely decreasing | 0.86 (0.61 – 1.2) | -0.032 (-0.097 – 0.031) | -21 (22 – -7.2) | -| testland | 2128 (1116 – 4273) | Likely decreasing | 0.85 (0.62 – 1.2) | -0.033 (-0.095 – 0.033) | -21 (21 – -7.3) | +| realland | 2141 (1156 – 3975) | Likely decreasing | 0.86 (0.63 – 1.1) | -0.031 (-0.092 – 0.029) | -22 (24 – -7.5) | +| testland | 2084 (1161 – 3800) | Likely decreasing | 0.85 (0.63 – 1.1) | -0.035 (-0.091 – 0.024) | -20 (29 – -7.6) | A range of plots are again returned (with the single summary plot shown below). diff --git a/man/backcalc_opts.Rd b/man/backcalc_opts.Rd index 89cd41520..030e94580 100644 --- a/man/backcalc_opts.Rd +++ b/man/backcalc_opts.Rd @@ -30,7 +30,7 @@ average to use when estimating Rt. This must be odd so that the central estimate is included.} } \value{ -A list of back calculation settings. +A \verb{<backcalc_opts>} object of back calculation settings. 
} \description{ \ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#stable}{\figure{lifecycle-stable.svg}{options: alt='[Stable]'}}}{\strong{[Stable]}} diff --git a/man/check_reports_valid.Rd b/man/check_reports_valid.Rd new file mode 100644 index 000000000..e87f74f24 --- /dev/null +++ b/man/check_reports_valid.Rd @@ -0,0 +1,34 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/checks.R +\name{check_reports_valid} +\alias{check_reports_valid} +\title{Validate data input} +\usage{ +check_reports_valid(reports, model) +} +\arguments{ +\item{reports}{A data frame with either: +\itemize{ +\item a minimum of two columns: \code{date} and \code{confirm}, if to be +used by \code{\link[=estimate_infections]{estimate_infections()}} or \code{\link[=estimate_truncation]{estimate_truncation()}}, or +\item a minimum of three columns: \code{date}, \code{primary}, and \code{secondary}, if to be +used by \code{\link[=estimate_secondary]{estimate_secondary()}}. +}} + +\item{model}{The EpiNow2 model to be used. Either +"estimate_infections", "estimate_truncation", or "estimate_secondary". +This is used to determine which checks to perform on the data input.} +} +\value{ +Called for its side effects. +} +\description{ +\code{check_reports_valid()} checks that the supplied data is a \verb{<data.frame>}, +and that it has the right column names and types. In particular, it checks +that the date column is in date format and does not contain NA's, and that +the other columns are numeric. +} +\author{ +James M. Azam +} +\keyword{internal} diff --git a/man/delay_opts.Rd b/man/delay_opts.Rd index 643f60fdb..970635e71 100644 --- a/man/delay_opts.Rd +++ b/man/delay_opts.Rd @@ -15,7 +15,7 @@ using \code{\link[=dist_spec]{dist_spec()}}. Default is an empty call to \code{\ \item{fixed}{deprecated; use \code{dist} instead} } \value{ -A list summarising the input delay distributions. +A \verb{<delay_opts>} object summarising the input delay distributions. 
} \description{ \ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#stable}{\figure{lifecycle-stable.svg}{options: alt='[Stable]'}}}{\strong{[Stable]}} diff --git a/man/figures/unnamed-chunk-15-1.png b/man/figures/unnamed-chunk-15-1.png index fdd72077e..5beed44bb 100644 Binary files a/man/figures/unnamed-chunk-15-1.png and b/man/figures/unnamed-chunk-15-1.png differ diff --git a/man/figures/unnamed-chunk-19-1.png b/man/figures/unnamed-chunk-19-1.png index cf5141eae..f2673e40a 100644 Binary files a/man/figures/unnamed-chunk-19-1.png and b/man/figures/unnamed-chunk-19-1.png differ diff --git a/man/generation_time_opts.Rd b/man/generation_time_opts.Rd index ff83c6690..b86907108 100644 --- a/man/generation_time_opts.Rd +++ b/man/generation_time_opts.Rd @@ -34,7 +34,8 @@ model option. Use the \code{weigh_delay_priors} argument of \code{estimate_infec instead.} } \value{ -A list summarising the input delay distributions. +A \verb{<generation_time_opts>} object summarising the input delay +distributions. } \description{ \ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#stable}{\figure{lifecycle-stable.svg}{options: alt='[Stable]'}}}{\strong{[Stable]}} diff --git a/man/gp_opts.Rd b/man/gp_opts.Rd index 36afb1ffd..edbcc587a 100644 --- a/man/gp_opts.Rd +++ b/man/gp_opts.Rd @@ -55,7 +55,7 @@ kernel ("matern", with \code{matern_type = 3/2}). Defaulting to the Matern 3 ove Currently only the Matern 3/2 kernel is supported.} } \value{ -A list of settings defining the Gaussian process +A \verb{<gp_opts>} object of settings defining the Gaussian process } \description{ \ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#stable}{\figure{lifecycle-stable.svg}{options: alt='[Stable]'}}}{\strong{[Stable]}} diff --git a/man/obs_opts.Rd b/man/obs_opts.Rd index bce8c6edf..f9c617dd5 100644 --- a/man/obs_opts.Rd +++ b/man/obs_opts.Rd @@ -45,7 +45,7 @@ included in the model.} be returned by the model.} } \value{ -A list of observation model settings.
+An \verb{<obs_opts>} object of observation model settings. } \description{ \ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#stable}{\figure{lifecycle-stable.svg}{options: alt='[Stable]'}}}{\strong{[Stable]}} diff --git a/man/rt_opts.Rd b/man/rt_opts.Rd index 8fa465e2c..7e828d572 100644 --- a/man/rt_opts.Rd +++ b/man/rt_opts.Rd @@ -55,7 +55,8 @@ the population that is susceptible. When set to 0 no population adjustment is done.} } \value{ -A list of settings defining the time-varying reproduction number. +An \verb{<rt_opts>} object with settings defining the time-varying +reproduction number. } \description{ \ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#stable}{\figure{lifecycle-stable.svg}{options: alt='[Stable]'}}}{\strong{[Stable]}} diff --git a/man/secondary_opts.Rd b/man/secondary_opts.Rd index 5071dec08..31bb22242 100644 --- a/man/secondary_opts.Rd +++ b/man/secondary_opts.Rd @@ -24,14 +24,14 @@ hospital admissions. options that can be passed.} } \value{ -A list of binary options summarising secondary model used in -\code{estimate_secondary()}. Options returned are \code{cumulative} (should the -secondary report be cumulative), \code{historic} (should a convolution of primary -reported cases be used to predict secondary reported cases), -\code{primary_hist_additive} (should the historic convolution of primary reported -cases be additive or subtractive), \code{current} (should currently observed -primary reported cases contribute to current secondary reported cases), -\code{primary_current_additive} (should current primary reported cases be +A \verb{<secondary_opts>} object of binary options summarising secondary +model used in \code{estimate_secondary()}.
Options returned are \code{cumulative} +(should the secondary report be cumulative), \code{historic} (should a +convolution of primary reported cases be used to predict secondary reported +cases), \code{primary_hist_additive} (should the historic convolution of primary +reported cases be additive or subtractive), \code{current} (should currently +observed primary reported cases contribute to current secondary reported +cases), \code{primary_current_additive} (should current primary reported cases be additive or subtractive). } \description{ diff --git a/man/stan_opts.Rd b/man/stan_opts.Rd index ffa6b5b54..c6d02adba 100644 --- a/man/stan_opts.Rd +++ b/man/stan_opts.Rd @@ -38,7 +38,8 @@ returned.} \item{...}{Additional parameters to pass underlying option functions.} } \value{ -A list of arguments to pass to the appropriate rstan functions. +A \verb{<stan_opts>} object of arguments to pass to the appropriate +rstan functions. } \description{ \ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#stable}{\figure{lifecycle-stable.svg}{options: alt='[Stable]'}}}{\strong{[Stable]}} diff --git a/man/trunc_opts.Rd b/man/trunc_opts.Rd index 8dba374b5..504078016 100644 --- a/man/trunc_opts.Rd +++ b/man/trunc_opts.Rd @@ -12,7 +12,8 @@ the truncation generated using \code{\link[=dist_spec]{dist_spec()}} or \code{\l Default is an empty call to \code{\link[=dist_spec]{dist_spec()}}, i.e. no truncation} } \value{ -A list summarising the input truncation distribution. +A \verb{<trunc_opts>} object summarising the input truncation +distribution.
} \description{ \ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#stable}{\figure{lifecycle-stable.svg}{options: alt='[Stable]'}}}{\strong{[Stable]}} diff --git a/tests/testthat/test-checks.R b/tests/testthat/test-checks.R new file mode 100644 index 000000000..e9670a6fc --- /dev/null +++ b/tests/testthat/test-checks.R @@ -0,0 +1,143 @@ +# Setup for testing ------------------------------------------------------- + +futile.logger::flog.threshold("FATAL") + +# Create reports data for estimate_infections() +est_inf <- EpiNow2::example_confirmed[1:10] + +# Create reports data for estimate_secondary() +est_sec <- data.table::copy(est_inf)[ + , + `:=`( + primary = confirm, + secondary = round(0.5 * confirm), + confirm = NULL + ) +] + +# Custom test functions --------------------------------------------------- +test_col_specs <- function(dt_list, model = "estimate_infections") { + expect_error( + check_reports_valid(dt_list$bad_col_name, + model = model + ) + ) + expect_error( + check_reports_valid(dt_list$bad_col_type, + model = model + ) + ) + expect_error( + check_reports_valid(dt_list$bad_col_entry, + model = model + ) + ) +} + +test_that("check_reports_valid errors for bad 'confirm' specifications", { + # Bad "confirm" column spec scenarios + confirm_col_dt <- list( + # Bad column name + bad_col_name = data.table::copy(est_inf)[ + , + `:=`( + confirm_bad_name = confirm, + confirm = NULL + ) + ], + # Bad column type + bad_col_type = data.table::copy(est_inf)[ + , + lapply(.SD, as.character), + by = confirm + ], + # Bad column entry + bad_col_entry = data.table::copy(est_inf)[ + , + confirm := -confirm + ] + ) + # Run tests + test_col_specs(confirm_col_dt, model = "estimate_infections") +}) + +test_that("check_reports_valid errors for bad 'date' specifications", { + # Bad "date" column spec scenarios + date_col_dt <- list( + # Bad
column name + bad_col_name = data.table::copy(est_inf)[ + , + `:=`( + date_bad_name = date, + date = NULL + ) + ], + # Bad column type + bad_col_type = data.table::copy(est_inf)[ + , + lapply(.SD, as.character), + by = date + ], + # Bad column entry + bad_col_entry = data.table::copy(est_inf)[ + c(1, 3), + date := NA + ] + ) + # Run tests + test_col_specs(date_col_dt, model = "estimate_infections") +}) + +test_that("check_reports_valid errors for bad 'primary' specifications", { + # Bad "primary" column spec scenarios + primary_col_dt <- list( + # Bad column name + bad_col_name = data.table::copy(est_sec)[ + , + `:=`( + primary_bad_name = primary, + primary = NULL + ) + ], + # Bad column type + bad_col_type = data.table::copy(est_sec)[ + , + lapply(.SD, as.character), + by = primary + ], + # Bad column entry + bad_col_entry = data.table::copy(est_sec)[ + , + primary := -primary + ] + ) + # Run tests + test_col_specs(primary_col_dt, model = "estimate_secondary") +}) + +test_that("check_reports_valid errors for bad 'secondary' specifications", { + # Bad "secondary" column spec scenarios + secondary_col_dt <- list( + # Bad column name + bad_col_name = data.table::copy(est_sec)[ + , + `:=`( + secondary_bad_name = secondary, + secondary = NULL + ) + ], + # Bad column type + bad_col_type = data.table::copy(est_sec)[ + , + lapply(.SD, as.character), + by = secondary + ], + # Bad column entry + bad_col_entry = data.table::copy(est_sec)[ + , + secondary := -secondary + ] + ) + # Run tests + test_col_specs(secondary_col_dt, model = "estimate_secondary") +}) diff --git a/tests/testthat/test-estimate_infections.R b/tests/testthat/test-estimate_infections.R index f574f3e1a..69303de5d 100644 --- a/tests/testthat/test-estimate_infections.R +++ b/tests/testthat/test-estimate_infections.R @@ -7,12 +7,12 @@ reported_cases <- EpiNow2::example_confirmed[1:30] default_estimate_infections <- function(..., add_stan = list(), delay = TRUE) { futile.logger::flog.threshold("FATAL") - def_stan
<- stan_opts( + def_stan <- list( chains = 2, warmup = 50, samples = 50, control = list(adapt_delta = 0.8) ) - stan_args <- def_stan[setdiff(names(def_stan), names(add_stan))] - stan_args <- c(stan_args, add_stan) + def_stan <- modifyList(def_stan, add_stan) + stan_args <- do.call(stan_opts, def_stan) suppressWarnings(estimate_infections(..., generation_time = generation_time_opts(example_generation_time), diff --git a/tests/testthat/test-estimate_secondary.R b/tests/testthat/test-estimate_secondary.R index 880ccd285..a887c2381 100644 --- a/tests/testthat/test-estimate_secondary.R +++ b/tests/testthat/test-estimate_secondary.R @@ -105,7 +105,7 @@ test_that("estimate_secondary works with weigh_delay_priors = TRUE", { mean = 2.5, mean_sd = 0.5, sd = 0.47, sd_sd = 0.25, max = 30 ) inc_weigh <- estimate_secondary( - cases[1:60], delays = delays, + cases[1:60], delays = delay_opts(delays), obs = obs_opts(scale = list(mean = 0.2, sd = 0.2), week_effect = FALSE), weigh_delay_priors = TRUE, verbose = FALSE )