diff --git a/DESCRIPTION b/DESCRIPTION index f99447de1..9b1f5896c 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -33,7 +33,9 @@ Imports: ggridges, here, stats, - cli + cli, + checkmate, + rstan Suggests: bookdown, testthat (>= 3.0.0), diff --git a/R/latent_individual.R b/R/latent_individual.R index 8f06d886c..304804c97 100644 --- a/R/latent_individual.R +++ b/R/latent_individual.R @@ -61,11 +61,24 @@ epidist_formula.epidist_latent_individual <- function(data, delay_central = ~ 1, return(form) } +#' @importFrom rstan lookup #' @method epidist_family epidist_latent_individual #' @family latent_individual #' @export epidist_family.epidist_latent_individual <- function(data, family = "lognormal", ...) { + checkmate::assert_string(family) + + pdf_lookup <- rstan::lookup("pdf") + valid_pdfs <- gsub("_lpdf", "", pdf_lookup$StanFunction) + if (!family %in% valid_pdfs) { + cli::cli_warn( + "The provided family {.code family} does not correspond to a valid LPDF + function available in rstan. (It is possible [but unlikely] that there is + such an LPDF in Stan available via cmdstanr, as rstan is behind Stan.)" + ) + } + brms::custom_family( paste0("latent_", family), dpars = c("mu", "sigma"), diff --git a/tests/testthat/test-unit-latent_individual.R b/tests/testthat/test-unit-latent_individual.R index 9b5abe686..34830cb85 100644 --- a/tests/testthat/test-unit-latent_individual.R +++ b/tests/testthat/test-unit-latent_individual.R @@ -55,6 +55,11 @@ test_that("epidist_family.epidist_latent_individual with default settings produc expect_s3_class(family, "family") }) +test_that("epidist_family.epidist_latent_individual warns users or gives an error when passed inappropriate family input", { # nolint: line_length_linter. + expect_error(epidist_family(prep_obs, family = 1)) + expect_warning(epidist_family(prep_obs, family = "not_a_real_lpdf")) +}) + test_that("the family argument in epidist_family.epidist_latent_individual passes as expected", { # nolint: line_length_linter. family_gamma <- epidist_family(prep_obs, family = "gamma") expect_equal(family_gamma$name, "latent_gamma")