diff --git a/DESCRIPTION b/DESCRIPTION index 3ceccaa9..7269e316 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -38,7 +38,8 @@ LazyData: true RoxygenNote: 7.2.3 Depends: R (>= 3.5.0) -Imports: +Imports: + checkmate, methods, Rcpp (>= 0.12.0), RcppParallel (>= 5.0.1), diff --git a/R/modelling.R b/R/modelling.R index 8558f109..b0bdd3fa 100644 --- a/R/modelling.R +++ b/R/modelling.R @@ -49,13 +49,15 @@ #' } #' @export run_seromodel <- function(serodata, - foi_model = "constant", + foi_model = c("constant", "tv_normal_log", + "tv_normal"), n_iters = 1000, n_thin = 2, delta = 0.90, m_treed = 10, decades = 0, print_summary = TRUE) { + foi_model <- match.arg(foi_model) survey <- unique(serodata$survey) if (length(survey) > 1) warning("You have more than 1 surveys or survey codes") seromodel_object <- fit_seromodel(serodata = serodata, @@ -122,13 +124,15 @@ run_seromodel <- function(serodata, #' #' @export fit_seromodel <- function(serodata, - foi_model, + foi_model = c("constant", "tv_normal_log", + "tv_normal"), n_iters = 1000, n_thin = 2, delta = 0.90, m_treed = 10, decades = 0) { # TODO Add a warning because there are exceptions where a minimal amount of iterations is needed + foi_model <- match.arg(foi_model) model <- stanmodels[[foi_model]] exposure_ages <- get_exposure_ages(serodata) exposure_years <- (min(serodata$birth_year):serodata$tsur[1])[-1] diff --git a/R/seroprevalence_data.R b/R/seroprevalence_data.R index f4168cfe..f0a033db 100644 --- a/R/seroprevalence_data.R +++ b/R/seroprevalence_data.R @@ -37,6 +37,16 @@ prepare_serodata <- function(serodata = serodata, alpha = 0.05, add_age_mean_f = TRUE) { + checkmate::assert_numeric(alpha, lower = 0, upper = 1) + checkmate::assert_logical(add_age_mean_f) + #Check that serodata has the right columns + stopifnot("serodata must contain the right columns" = + setequal(names(serodata), + c("survey", "total", "counts", "age_min","age_max", + "tsur", "country","test", "antibody" + ) + ) + ) if(add_age_mean_f){ serodata <- serodata %>% dplyr::mutate(age_mean_f = floor((age_min + age_max) / 2), sample_size = sum(total)) %>%