Add input validation #476

Merged
merged 59 commits into from
Nov 14, 2023
Changes from 57 commits
Commits
de54916
Add function to check validity of reports input
jamesmbaazam Oct 12, 2023
242c4f4
Adopt reports checking function
jamesmbaazam Oct 12, 2023
73d4817
Import checkmate
jamesmbaazam Oct 12, 2023
6293fa7
Linting: remove whitespace
jamesmbaazam Oct 12, 2023
c486b10
Remove explicit namespacing
jamesmbaazam Oct 13, 2023
83f28db
Import checkmate functions
jamesmbaazam Oct 13, 2023
5559a55
Revise function description
jamesmbaazam Oct 13, 2023
de2b6e5
Make _opts functions return their own class
jamesmbaazam Oct 13, 2023
650271b
Import functions for use
jamesmbaazam Oct 13, 2023
04f8ff3
Add input checking to the estimation functions
jamesmbaazam Oct 13, 2023
4af70cc
Return an object of class secondary_opts
jamesmbaazam Oct 13, 2023
b954ac2
Allow for NULL inputs
jamesmbaazam Oct 19, 2023
0544e8f
Change burn_in check from integer to numeric
jamesmbaazam Oct 19, 2023
d47606b
Update returned objects
jamesmbaazam Oct 19, 2023
d10e236
Allow NULL args
jamesmbaazam Oct 19, 2023
2e50956
Make lower bound of horizon 0
jamesmbaazam Oct 19, 2023
d291b12
Don't import assert_integer
jamesmbaazam Oct 19, 2023
c4064c3
Pass delays to delay_opts
jamesmbaazam Oct 23, 2023
ee247d9
Don't return a class
jamesmbaazam Oct 23, 2023
30c6deb
Catch the dot-dot-dot and modify iter if present
jamesmbaazam Oct 23, 2023
ba00fc1
Convert argument to stan_opts objects
jamesmbaazam Oct 23, 2023
7cd4bc4
Revise wording of return type
jamesmbaazam Oct 23, 2023
b50c450
Remove import of assert_integer
jamesmbaazam Oct 23, 2023
8af25bb
Linting: Remove trailing whitespace
jamesmbaazam Oct 23, 2023
0987274
Linting: trim white space
jamesmbaazam Oct 24, 2023
65f31d3
Add tests for check_reports_valid function
jamesmbaazam Oct 26, 2023
e234eec
Automatic readme update
actions-user Oct 26, 2023
6187da7
Check inputs of simulate_infection()
jamesmbaazam Oct 26, 2023
164a4e3
Check inputs of epinow function
jamesmbaazam Oct 26, 2023
30befbd
Check that target_date is a string, not a date
jamesmbaazam Oct 27, 2023
9f70cfd
Check that target_folder is specified as a directory
jamesmbaazam Oct 27, 2023
8b52163
Import functions for assertions
jamesmbaazam Oct 27, 2023
e360309
Check that estimates input contains a "fit" element
jamesmbaazam Oct 27, 2023
6c56694
Various checks on "R" input
jamesmbaazam Oct 27, 2023
d55876a
Fix docs of check_reports_valid
jamesmbaazam Oct 30, 2023
dd422e3
Fix docs of check_reports_valid
jamesmbaazam Oct 30, 2023
82839d8
Fix docs of check_reports_valid
jamesmbaazam Oct 30, 2023
a6f2e89
Fix docs of check_reports_valid
jamesmbaazam Oct 30, 2023
c2c330f
Error if column names are outside of allowed set
jamesmbaazam Oct 30, 2023
34f8dd7
Remove boolean argument in favour of enumerated options
jamesmbaazam Oct 30, 2023
688190d
Generate fixed doc
jamesmbaazam Oct 30, 2023
1b5ce1f
Update docs to mention returned class
jamesmbaazam Oct 30, 2023
d47d55e
Replace binary argument with enumerated argument
jamesmbaazam Oct 30, 2023
dbeeae5
Don't allow null values
jamesmbaazam Oct 30, 2023
8323b68
Revert to have a minimal set of allowed columns
jamesmbaazam Oct 31, 2023
d908afc
Improve documentation
jamesmbaazam Oct 31, 2023
20e4f25
Allow model argument to be null.
jamesmbaazam Oct 31, 2023
e06d1f4
Improve typesetting
jamesmbaazam Oct 31, 2023
61df073
Styling
jamesmbaazam Oct 31, 2023
7a61847
Fix a wrong function link in docs
jamesmbaazam Oct 31, 2023
bb89302
Remove awkward line breaks
jamesmbaazam Nov 1, 2023
46fcf69
Use rlang arg_match() to prevent partial match
jamesmbaazam Nov 1, 2023
5410759
Add news item
jamesmbaazam Nov 1, 2023
db0bfce
Pass arguments to `stan_opts` as list in tests
sbfnk Nov 13, 2023
9b94515
Linting: Use == to match length-1 scalars, not %in%
jamesmbaazam Nov 13, 2023
7ccbc70
Linting: Switch logic of ifelse statement
jamesmbaazam Nov 13, 2023
3931cbc
Linting: construct paths with file.path() instead of paste0.
jamesmbaazam Nov 13, 2023
841d340
Linting: Fix indentation
jamesmbaazam Nov 14, 2023
da6a20d
Fix wrong return value
jamesmbaazam Nov 14, 2023
1 change: 1 addition & 0 deletions DESCRIPTION
@@ -99,6 +99,7 @@ BugReports: https://github.com/epiforecasts/EpiNow2/issues
Depends:
R (>= 3.5.0)
Imports:
checkmate,
data.table,
futile.logger (>= 1.4),
future,
14 changes: 14 additions & 0 deletions NAMESPACE
@@ -100,7 +100,20 @@ export(update_secondary_args)
import(Rcpp)
import(methods)
import(rstantools)
importFrom(R.utils,isDirectory)
importFrom(R.utils,withTimeout)
importFrom(checkmate,assert_character)
importFrom(checkmate,assert_class)
importFrom(checkmate,assert_data_frame)
importFrom(checkmate,assert_date)
importFrom(checkmate,assert_integerish)
importFrom(checkmate,assert_logical)
importFrom(checkmate,assert_names)
importFrom(checkmate,assert_numeric)
importFrom(checkmate,assert_path_for_output)
importFrom(checkmate,assert_string)
importFrom(checkmate,test_data_frame)
importFrom(checkmate,test_numeric)
importFrom(data.table,":=")
importFrom(data.table,.N)
importFrom(data.table,as.data.table)
@@ -186,6 +199,7 @@ importFrom(purrr,safely)
importFrom(purrr,transpose)
importFrom(purrr,walk)
importFrom(rlang,abort)
importFrom(rlang,arg_match)
importFrom(rlang,cnd_muffle)
importFrom(rlang,warn)
importFrom(rstan,expose_stan_functions)
1 change: 1 addition & 0 deletions NEWS.md
@@ -18,6 +18,7 @@
* Reduced the number of long-running examples. By @sbfnk in #459 and reviewed by @seabbs.
* Changed all instances of arguments that refer to the maximum of a distribution to reflect the maximum. Previously this did, in some instance, refer to the length of the PMF. By @sbfnk in #468.
* Fixed a bug in the bounds of delays when setting initial conditions. By @sbfnk in #474.
* Added input checking to `estimate_infections()`, `estimate_secondary()`, `estimate_truncation()`, `simulate_infections()`, and `epinow()`. `check_reports_valid()` has been added to validate the reports dataset passed to these functions. Tests are added to check `check_reports_valid()`. As part of input validation, the various `*_opts()` functions now return subclasses of the same name as the functions and are tested against passed arguments to ensure the right `*_opts()` is passed to the right argument. For example, the `obs` argument in `estimate_secondary()` is expected to only receive arguments passed through `obs_opts()` and will error otherwise. By @jamesmbaazam in #476 and reviewed by @sbfnk and @seabbs.
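The class-based checking described in this news item can be sketched roughly as follows. This is a minimal illustration, not the EpiNow2 implementation: `toy_obs_opts()` and `toy_estimate()` are hypothetical stand-ins, and only `checkmate::assert_class()` is a real library call.

```r
library(checkmate)

# Hypothetical stand-in for an `*_opts()` constructor: it returns a list
# that carries a subclass named after the function.
toy_obs_opts <- function(family = "negbin") {
  opts <- list(family = family)
  class(opts) <- c("obs_opts", class(opts))
  opts
}

# Hypothetical estimation function that only accepts `obs_opts` objects.
toy_estimate <- function(obs = toy_obs_opts()) {
  assert_class(obs, "obs_opts")
  obs$family
}

# An object built by the constructor passes the check.
ok <- toy_estimate(toy_obs_opts("poisson"))

# A plain list with the same contents is rejected by assert_class().
res <- tryCatch(
  toy_estimate(list(family = "poisson")),
  error = function(e) conditionMessage(e)
)
```

Under these assumptions, passing a bare list fails early with an assertion message instead of surfacing later inside the model code.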

## Model changes

56 changes: 56 additions & 0 deletions R/checks.R
@@ -0,0 +1,56 @@
#' Validate data input
#'
#' @description
#' `check_reports_valid()` checks that the supplied data is a `<data.frame>`,
#' and that it has the right column names and types. In particular, it checks
#' that the date column is in date format and does not contain NA's, and that
#' the other columns are numeric.
#'
#' @param reports A data frame with either:
#' * a minimum of two columns: `date` and `confirm`, if to be
#' used by [estimate_infections()] or [estimate_truncation()], or
#' * a minimum of three columns: `date`, `primary`, and `secondary`, if to be
#' used by [estimate_secondary()].
#' @param model The EpiNow2 model to be used. Either
#' "estimate_infections", "estimate_truncation", or "estimate_secondary".
#' This is used to determine which checks to perform on the data input.
#' @importFrom checkmate assert_data_frame assert_date assert_names
#' assert_numeric
#' @importFrom rlang arg_match
#' @return Called for its side effects.
#' @author James M. Azam
#' @keywords internal
check_reports_valid <- function(reports, model) {
# Check that the case time series (reports) is a data frame
assert_data_frame(reports)
# Perform checks depending on the model the data is meant to be used with
model <- arg_match(
model,
values = c(
"estimate_infections",
"estimate_truncation",
"estimate_secondary"
)
)

if (model == "estimate_secondary") {
# Check that reports has the right column names
assert_names(
names(reports),
must.include = c("date", "primary", "secondary")
)
# Check that the reports data.frame has the right column types
assert_date(reports$date, any.missing = FALSE)
assert_numeric(reports$primary, lower = 0)
assert_numeric(reports$secondary, lower = 0)
} else {
# Check that reports has the right column names
assert_names(
names(reports),
must.include = c("date", "confirm")
)
# Check that the reports data.frame has the right column types
assert_date(reports$date, any.missing = FALSE)
assert_numeric(reports$confirm, lower = 0)
}
}
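Because `check_reports_valid()` is marked internal, the sketch below does not call it. Instead it applies the same `checkmate` assertions from the `else` branch directly to a small, made-up `reports` data frame, to show what kind of input passes and what would be rejected. The data values are illustrative assumptions.

```r
library(checkmate)

# Made-up case time series with the two required columns.
reports <- data.frame(
  date = as.Date("2023-01-01") + 0:2,
  confirm = c(10, 12, 9)
)

# The same checks as the non-secondary branch above; all pass silently.
assert_data_frame(reports)
assert_names(names(reports), must.include = c("date", "confirm"))
assert_date(reports$date, any.missing = FALSE)
assert_numeric(reports$confirm, lower = 0)

# The test_*() variants return TRUE/FALSE instead of raising an error,
# which makes it easy to see what would be rejected.
bad_dates <- as.character(reports$date)          # character, not Date
date_ok <- test_date(bad_dates, any.missing = FALSE)
counts_ok <- test_numeric(c(-1, 12, 9), lower = 0) # negative count
```

Here `date_ok` and `counts_ok` are both `FALSE`, so the corresponding `assert_*()` calls would stop with an error for such input.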
10 changes: 5 additions & 5 deletions R/create.R
@@ -170,9 +170,9 @@ create_future_rt <- function(future = "latest", delay = 0) {
"estimate"
)
)
- if (!(future %in% "project")) {
+ if (!(future == "project")) {
out$fixed <- TRUE
- out$from <- ifelse(future %in% "latest", 0, -delay)
+ out$from <- ifelse(future == "latest", 0, -delay)
}
} else if (is.numeric(future)) {
out$fixed <- TRUE
@@ -227,7 +227,7 @@ create_rt_data <- function(rt = rt_opts(), breakpoints = NULL,
# apply random walk
if (rt$rw != 0) {
breakpoints <- as.integer(seq_along(breakpoints) %% rt$rw == 0)
- if (!(rt$future %in% "project")) {
+ if (!(rt$future == "project")) {
max_bps <- length(breakpoints) - horizon + future_rt$from
if (max_bps < length(breakpoints)) {
breakpoints[(max_bps + 1):length(breakpoints)] <- 0
@@ -248,7 +248,7 @@ create_rt_data <- function(rt = rt_opts(), breakpoints = NULL,
future_fixed = as.numeric(future_rt$fixed),
fixed_from = future_rt$from,
pop = rt$pop,
- stationary = as.numeric(rt$gp_on %in% "R0"),
+ stationary = as.numeric(rt$gp_on == "R0"),
future_time = horizon - future_rt$from
)
return(rt_data)
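The `%in%` to `==` replacements in this file are equivalent for ordinary length-1 strings, but the two operators are not identical in general. A small base R sketch of the difference for missing values follows; the variable names are illustrative only and do not come from the package.

```r
# For a normal length-1 string both operators agree.
future <- "latest"
in_plain <- future %in% "project"  # FALSE
eq_plain <- future == "project"    # FALSE

# For a missing value they differ: %in% never returns NA here,
# while == propagates the NA.
in_missing <- NA_character_ %in% "project"  # FALSE
eq_missing <- NA_character_ == "project"    # NA
```

Since an `NA` condition makes `if ()` stop with an error, the `==` form relies on the argument having been validated upstream as a non-missing scalar.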
@@ -383,7 +383,7 @@ create_gp_data <- function(gp = gp_opts(), data) {
#' create_obs_model(obs_opts(week_length = 3), dates = dates)
create_obs_model <- function(obs = obs_opts(), dates) {
data <- list(
- model_type = as.numeric(obs$family %in% "negbin"),
+ model_type = as.numeric(obs$family == "negbin"),
phi_mean = obs$phi[1],
phi_sd = obs$phi[2],
week_effect = ifelse(obs$week_effect, obs$week_length, 1),
22 changes: 18 additions & 4 deletions R/epinow.R
@@ -33,6 +33,9 @@
#' @importFrom lubridate days
#' @importFrom futile.logger flog.fatal flog.warn flog.error flog.debug ftry
#' @importFrom rlang cnd_muffle
#' @importFrom checkmate assert_string assert_path_for_output
#' assert_date assert_logical
#' @importFrom R.utils isDirectory
#' @author Sam Abbott
#' @examples
#' \donttest{
@@ -105,6 +108,17 @@ epinow <- function(reported_cases,
plot_args = list(),
target_folder = NULL, target_date,
logs = tempdir(), id = "epinow", verbose = interactive()) {
# Check inputs
assert_logical(return_output)
stopifnot("target_folder is not a directory" =
!is.null(target_folder) || isDirectory(target_folder)
)
if (!missing(target_date)) {
assert_string(target_date)
}
assert_string(id)
assert_logical(verbose)

if (is.null(target_folder)) {
return_output <- TRUE
}
@@ -251,7 +265,7 @@ epinow <- function(reported_cases,
}
),
error = function(e) {
- if (id %in% "epinow") {
+ if (id == "epinow") {
stop(e)
} else {
error_text <- sprintf("%s: %s - %s", id, e$message, toString(e$call))
@@ -269,15 +283,15 @@ epinow <- function(reported_cases,
}

if (!is.null(target_folder) && !is.null(out$error)) {
- saveRDS(out$error, paste0(target_folder, "/error.rds"))
- saveRDS(out$trace, paste0(target_folder, "/trace.rds"))
+ saveRDS(out$error, file.path(target_folder, "error.rds"))
+ saveRDS(out$trace, file.path(target_folder, "trace.rds"))
}

# log timing if specified
if (output["timing"]) {
out$timing <- round(as.numeric(end_time - start_time), 1)
if (!is.null(target_folder)) {
- saveRDS(out$timing, paste0(target_folder, "/runtime.rds"))
+ saveRDS(out$timing, file.path(target_folder, "runtime.rds"))
}
}

56 changes: 38 additions & 18 deletions R/estimate_infections.R
@@ -65,6 +65,8 @@
#' @importFrom lubridate days
#' @importFrom purrr transpose
#' @importFrom futile.logger flog.threshold flog.warn flog.debug
#' @importFrom checkmate assert_class assert_numeric assert_logical
#' assert_string
#' @examples
#' \donttest{
#' # set number of cores to use
@@ -131,6 +133,24 @@ estimate_infections <- function(reported_cases,
weigh_delay_priors = TRUE,
id = "estimate_infections",
verbose = interactive()) {
# Validate inputs
check_reports_valid(reported_cases, model = "estimate_infections")
assert_class(generation_time, "generation_time_opts")
assert_class(delays, "delay_opts")
assert_class(truncation, "trunc_opts")
assert_class(rt, "rt_opts", null.ok = TRUE)
assert_class(backcalc, "backcalc_opts")
assert_class(gp, "gp_opts", null.ok = TRUE)
assert_class(obs, "obs_opts")
assert_class(stan, "stan_opts")
assert_numeric(horizon, lower = 0)
assert_numeric(CrIs, lower = 0, upper = 1)
assert_logical(filter_leading_zeros)
assert_numeric(zero_threshold, lower = 0)
assert_logical(weigh_delay_priors)
assert_string(id)
assert_logical(verbose)

set_dt_single_thread()

# store dirty reported case data
@@ -211,7 +231,7 @@ estimate_infections <- function(reported_cases,
# Initialise fitting by using a previous fit or fitting to cumulative cases
if (!is.null(args$init_fit)) {
if (!inherits(args$init_fit, "stanfit") &&
- args$init_fit %in% "cumulative") {
+ args$init_fit == "cumulative") {
args$init_fit <- init_cumulative_fit(args,
warmup = 50, samples = 50,
id = id, verbose = FALSE
@@ -435,26 +455,15 @@ fit_model_with_nuts <- function(args, future = FALSE, max_execution_time = Inf,
}
}

- if (!future) {
- fit <- fit_chain(1,
- stan_args = args, max_time = max_execution_time,
- catch = !id %in% c("estimate_infections", "epinow")
- )
- if (stuck_chains > 0) {
- fit <- NULL
- }
- if (is.null(fit)) {
- rlang::abort("model fitting was timed out or failed")
- }
- } else {
+ if (future) {
chains <- args$chains
args$chains <- 1
args$cores <- 1
fits <- future.apply::future_lapply(1:chains, fit_chain,
- stan_args = args,
- max_time = max_execution_time,
- catch = TRUE,
- future.seed = TRUE
+ stan_args = args,
+ max_time = max_execution_time,
+ catch = TRUE,
+ future.seed = TRUE
)
if (stuck_chains > 0) {
fits[1:stuck_chains] <- NULL
@@ -478,12 +487,23 @@ fit_model_with_nuts <- function(args, future = FALSE, max_execution_time = Inf,
if ((chains - failed_chains) < 2) {
rlang::abort(
"model fitting failed as too few chains were returned to assess",
- " convergence (2 or more required)"
+ " convergence (2 or more required)"
)
}
}
fit <- rstan::sflist2stanfit(fit)
}
+ } else {
+ fit <- fit_chain(1,
+ stan_args = args, max_time = max_execution_time,
+ catch = !id %in% c("estimate_infections", "epinow")
+ )
+ if (stuck_chains > 0) {
+ fit <- NULL
+ }
+ if (is.null(fit)) {
+ rlang::abort("model fitting was timed out or failed")
+ }
}
return(fit)
}
36 changes: 26 additions & 10 deletions R/estimate_secondary.R
@@ -66,6 +66,8 @@
#' @importFrom lubridate wday
#' @importFrom data.table as.data.table merge.data.table
#' @importFrom utils modifyList
#' @importFrom checkmate assert_class assert_numeric assert_data_frame
#' assert_logical
#' @examples
#' \donttest{
#' # set number of cores to use
@@ -149,6 +151,19 @@ estimate_secondary <- function(reports,
weigh_delay_priors = FALSE,
verbose = interactive(),
...) {
# Validate the inputs
check_reports_valid(reports, model = "estimate_secondary")
assert_class(secondary, "secondary_opts")
assert_class(delays, "delay_opts")
assert_class(truncation, "trunc_opts")
assert_class(obs, "obs_opts")
assert_numeric(burn_in, lower = 0)
assert_numeric(CrIs, lower = 0, upper = 1)
assert_data_frame(priors, null.ok = TRUE)
assert_class(model, "stanfit", null.ok = TRUE)
assert_logical(weigh_delay_priors)
assert_logical(verbose)

reports <- data.table::as.data.table(reports)

if (burn_in >= nrow(reports)) {
@@ -238,14 +253,14 @@ estimate_secondary <- function(reports,
#' options that can be passed.
#'
#' @seealso estimate_secondary
- #' @return A list of binary options summarising secondary model used in
- #' `estimate_secondary()`. Options returned are `cumulative` (should the
- #' secondary report be cumulative), `historic` (should a convolution of primary
- #' reported cases be used to predict secondary reported cases),
- #' `primary_hist_additive` (should the historic convolution of primary reported
- #' cases be additive or subtractive), `current` (should currently observed
- #' primary reported cases contribute to current secondary reported cases),
- #' `primary_current_additive` (should current primary reported cases be
+ #' @return A `<secondary_opts>` object of binary options summarising secondary
+ #' model used in `estimate_secondary()`. Options returned are `cumulative`
+ #' (should the secondary report be cumulative), `historic` (should a
+ #' convolution of primary reported cases be used to predict secondary reported
+ #' cases), `primary_hist_additive` (should the historic convolution of primary
+ #' reported cases be additive or subtractive), `current` (should currently
+ #' observed primary reported cases contribute to current secondary reported
+ #' cases), `primary_current_additive` (should current primary reported cases be
#' additive or subtractive).
#'
#' @export
@@ -258,15 +273,15 @@ estimate_secondary <- function(reports,
#' secondary_opts("prevalence")
secondary_opts <- function(type = "incidence", ...) {
type <- match.arg(type, choices = c("incidence", "prevalence"))
- if (type %in% "incidence") {
+ if (type == "incidence") {
data <- list(
cumulative = 0,
historic = 1,
primary_hist_additive = 1,
current = 0,
primary_current_additive = 0
)
- } else if (type %in% "prevalence") {
+ } else if (type == "prevalence") {
data <- list(
cumulative = 1,
historic = 1,
@@ -276,6 +291,7 @@ secondary_opts <- function(type = "incidence", ...) {
)
}
data <- modifyList(data, list(...))
attr(data, "class") <- c("secondary_opts", class(data))
return(data)
}
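`secondary_opts()` above selects `type` with base `match.arg()`, while `check_reports_valid()` in this pull request uses `rlang::arg_match()` to prevent partial matching. The difference can be sketched as follows; `base_match()` and `strict_match()` are hypothetical wrappers written for this illustration, not package functions.

```r
# Hypothetical wrapper around base match.arg(), which allows
# unambiguous partial matches.
base_match <- function(type = "incidence") {
  match.arg(type, choices = c("incidence", "prevalence"))
}

# Hypothetical wrapper around rlang::arg_match(), which requires
# an exact match against `values`.
strict_match <- function(type = "incidence") {
  rlang::arg_match(type, values = c("incidence", "prevalence"))
}

# Partial input is silently completed by match.arg().
partial <- base_match("prev")

# The same input is rejected by arg_match().
strict <- tryCatch(
  strict_match("prev"),
  error = function(e) "rejected"
)
```

With these definitions `partial` is `"prevalence"`, whereas `strict` is `"rejected"` because `arg_match()` raises an error for the abbreviated value.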
