From e4f4888a3e6d6e15ca03757ab01fb21c9a7ca3d8 Mon Sep 17 00:00:00 2001 From: James Azam Date: Thu, 5 Oct 2023 14:42:13 +0100 Subject: [PATCH] Add input validation (#92) * Added argument matching in fit_seromodel() * Added allowed options to arg foi_model * Added input validation to prepare_serodata() * Fixed the check for required columns * Regenerated the docs --------- Co-authored-by: ntorresd --- DESCRIPTION | 3 ++- R/modelling.R | 8 ++++++-- R/seroprevalence_data.R | 11 +++++++++++ man/fit_seromodel.Rd | 2 +- man/run_seromodel.Rd | 2 +- 5 files changed, 21 insertions(+), 5 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index f2ee2530..68b013bf 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -38,7 +38,8 @@ LazyData: true RoxygenNote: 7.2.3 Depends: R (>= 3.5.0) -Imports: +Imports: + checkmate, methods, Rcpp (>= 0.12.0), RcppParallel (>= 5.0.1), diff --git a/R/modelling.R b/R/modelling.R index 973b699a..c063bdc1 100644 --- a/R/modelling.R +++ b/R/modelling.R @@ -48,13 +48,15 @@ #' foi_model = "constant") #' @export run_seromodel <- function(serodata, - foi_model = "constant", + foi_model = c("constant", "tv_normal_log", + "tv_normal"), n_iters = 1000, n_thin = 2, delta = 0.90, m_treed = 10, decades = 0, print_summary = TRUE) { + foi_model <- match.arg(foi_model) survey <- unique(serodata$survey) if (length(survey) > 1) warning("You have more than 1 surveys or survey codes") seromodel_object <- fit_seromodel(serodata = serodata, @@ -102,13 +104,15 @@ run_seromodel <- function(serodata, #' #' @export fit_seromodel <- function(serodata, - foi_model, + foi_model = c("constant", "tv_normal_log", + "tv_normal"), n_iters = 1000, n_thin = 2, delta = 0.90, m_treed = 10, decades = 0) { # TODO Add a warning because there are exceptions where a minimal amount of iterations is needed + foi_model <- match.arg(foi_model) model <- stanmodels[[foi_model]] cohort_ages <- get_cohort_ages(serodata = serodata) exposure_matrix <- get_exposure_matrix(serodata) diff --git a/R/seroprevalence_data.R b/R/seroprevalence_data.R index d6ab599d..933e8f27 100644 --- a/R/seroprevalence_data.R +++ b/R/seroprevalence_data.R @@ -33,14 +33,25 @@ #' @export prepare_serodata <- function(serodata = serodata, alpha = 0.05) { + checkmate::assert_numeric(alpha, lower = 0, upper = 1) + #Check that serodata has the right columns + stopifnot("serodata must contain the right columns" = + all(c("survey", "total", "counts", "age_min", "age_max", "tsur", + "country","test","antibody" + ) %in% + colnames(serodata) + ) + ) if(!any(colnames(serodata) == "age_mean_f")){ serodata <- serodata %>% dplyr::mutate(age_mean_f = floor((age_min + age_max) / 2), sample_size = sum(total)) } + if(!any(colnames(serodata) == "birth_year")){ serodata <- serodata %>% dplyr::mutate(birth_year = .data$tsur - .data$age_mean_f) } + serodata <- serodata %>% cbind( Hmisc::binconf( diff --git a/man/fit_seromodel.Rd b/man/fit_seromodel.Rd index 12cacbf1..9f17a451 100644 --- a/man/fit_seromodel.Rd +++ b/man/fit_seromodel.Rd @@ -6,7 +6,7 @@ \usage{ fit_seromodel( serodata, - foi_model, + foi_model = c("constant", "tv_normal_log", "tv_normal"), n_iters = 1000, n_thin = 2, delta = 0.9, diff --git a/man/run_seromodel.Rd b/man/run_seromodel.Rd index f335bc4d..9d70340f 100644 --- a/man/run_seromodel.Rd +++ b/man/run_seromodel.Rd @@ -6,7 +6,7 @@ \usage{ run_seromodel( serodata, - foi_model = "constant", + foi_model = c("constant", "tv_normal_log", "tv_normal"), n_iters = 1000, n_thin = 2, delta = 0.9,