From a9b339ac94f9b52a4426d15c696bd28866ddea2c Mon Sep 17 00:00:00 2001 From: Adam Howes Date: Mon, 21 Oct 2024 17:37:41 +0100 Subject: [PATCH] Issue #386: Add direct model allowing pass through to `brms` (#393) * Template for functions required by direct model * Template for direct_model tests * First try at complete version of direct model (without documentation) * Document and lint, plus move default formula to delay ~ . * Update pkgdown * Document following merge --- NAMESPACE | 4 ++ R/direct_model.R | 68 +++++++++++++++++++ R/formula.R | 3 + _pkgdown.yml | 4 ++ man/as_direct_model.Rd | 30 ++++++++ man/as_latent_individual.Rd | 2 + ..._family_model.epidist_latent_individual.Rd | 1 + ...formula_model.epidist_latent_individual.Rd | 1 + man/epidist_validate.epidist_direct_model.Rd | 24 +++++++ ...dist_validate.epidist_latent_individual.Rd | 1 + man/is_direct_model.Rd | 23 +++++++ man/is_latent_individual.Rd | 3 +- tests/testthat/test-direct_model.R | 50 ++++++++++++++ tests/testthat/test-int-direct_model.R | 36 ++++++++++ 14 files changed, 249 insertions(+), 1 deletion(-) create mode 100644 R/direct_model.R create mode 100644 man/as_direct_model.Rd create mode 100644 man/epidist_validate.epidist_direct_model.Rd create mode 100644 man/is_direct_model.Rd create mode 100644 tests/testthat/test-direct_model.R create mode 100644 tests/testthat/test-int-direct_model.R diff --git a/NAMESPACE b/NAMESPACE index 76f052cae..70ce4e9ec 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -3,6 +3,7 @@ S3method(add_mean_sd,default) S3method(add_mean_sd,gamma_samples) S3method(add_mean_sd,lognormal_samples) +S3method(as_direct_model,data.frame) S3method(as_latent_individual,data.frame) S3method(epidist,default) S3method(epidist_family_model,default) @@ -17,9 +18,11 @@ S3method(epidist_model_prior,default) S3method(epidist_stancode,default) S3method(epidist_stancode,epidist_latent_individual) S3method(epidist_validate,default) +S3method(epidist_validate,epidist_direct_model) S3method(epidist_validate,epidist_latent_individual) export(add_event_vars) export(add_mean_sd) +export(as_direct_model) export(as_latent_individual) export(epidist) export(epidist_diagnostics) @@ -35,6 +38,7 @@ export(epidist_stancode) export(epidist_validate) export(filter_obs_by_obs_time) export(filter_obs_by_ptime) +export(is_direct_model) export(is_latent_individual) export(observe_process) export(predict_delay_parameters) diff --git a/R/direct_model.R b/R/direct_model.R new file mode 100644 index 000000000..73ff566bf --- /dev/null +++ b/R/direct_model.R @@ -0,0 +1,68 @@ +#' Prepare direct model to pass through to `brms` +#' +#' @param data A `data.frame` containing line list data +#' @family direct_model +#' @export +as_direct_model <- function(data) { + UseMethod("as_direct_model") +} + +assert_direct_model_input <- function(data) { + assert_data_frame(data) + assert_names(names(data), must.include = c("case", "ptime", "stime")) + assert_integer(data$case, lower = 0) + assert_numeric(data$ptime, lower = 0) + assert_numeric(data$stime, lower = 0) +} + +#' Prepare latent individual model +#' +#' This function prepares data for use with the direct model. It does this by +#' adding columns used in the model to the `data` object provided. To do this, +#' the `data` must already have columns for the case number (integer), +#' (positive, numeric) times for the primary and secondary event times. The +#' output of this function is a `epidist_direct_model` class object, which may +#' be passed to [epidist()] to perform inference for the model. +#' +#' @param data A `data.frame` containing line list data +#' @rdname as_direct_model +#' @method as_direct_model data.frame +#' @family direct_model +#' @autoglobal +#' @export +as_direct_model.data.frame <- function(data) { + assert_direct_model_input(data) + class(data) <- c("epidist_direct_model", class(data)) + data <- data |> + mutate(delay = .data$stime - .data$ptime) + epidist_validate(data) + return(data) +} + +#' Validate direct model data +#' +#' This function checks whether the provided `data` object is suitable for +#' running the direct model. As well as making sure that +#' `is_direct_model()` is true, it also checks that `data` is a `data.frame` +#' with the correct columns. +#' +#' @param data A `data.frame` containing line list data +#' @param ... ... +#' @method epidist_validate epidist_direct_model +#' @family direct_model +#' @export +epidist_validate.epidist_direct_model <- function(data, ...) { + assert_true(is_direct_model(data)) + assert_direct_model_input(data) + assert_names(names(data), must.include = c("case", "ptime", "stime", "delay")) + assert_numeric(data$delay, lower = 0) +} + +#' Check if data has the `epidist_direct_model` class +#' +#' @param data A `data.frame` containing line list data +#' @family latent_individual +#' @export +is_direct_model <- function(data) { + inherits(data, "epidist_direct_model") +} diff --git a/R/formula.R b/R/formula.R index a917db246..450f87813 100644 --- a/R/formula.R +++ b/R/formula.R @@ -37,5 +37,8 @@ epidist_formula_model <- function(data, formula, ...) { #' @family formula #' @export epidist_formula_model.default <- function(data, formula, ...) { + formula <- stats::update( + formula, delay ~ . + ) return(formula) } diff --git a/_pkgdown.yml b/_pkgdown.yml index f5ef239b1..06b1422ac 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -69,6 +69,10 @@ reference: desc: Specific methods for the latent individual model contents: - has_concept("latent_individual") +- title: Direct model + desc: Specific methods for the direct model + contents: + - has_concept("direct_model") - title: Postprocess desc: Functions for postprocessing model output contents: diff --git a/man/as_direct_model.Rd b/man/as_direct_model.Rd new file mode 100644 index 000000000..f2c9d8ed7 --- /dev/null +++ b/man/as_direct_model.Rd @@ -0,0 +1,30 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/direct_model.R +\name{as_direct_model} +\alias{as_direct_model} +\alias{as_direct_model.data.frame} +\title{Prepare direct model to pass through to \code{brms}} +\usage{ +as_direct_model(data) + +\method{as_direct_model}{data.frame}(data) +} +\arguments{ +\item{data}{A \code{data.frame} containing line list data} +} +\description{ +This function prepares data for use with the direct model. It does this by +adding columns used in the model to the \code{data} object provided. To do this, +the \code{data} must already have columns for the case number (integer), +(positive, numeric) times for the primary and secondary event times. The +output of this function is a \code{epidist_direct_model} class object, which may +be passed to \code{\link[=epidist]{epidist()}} to perform inference for the model. +} +\seealso{ +Other direct_model: +\code{\link{epidist_validate.epidist_direct_model}()} + +Other direct_model: +\code{\link{epidist_validate.epidist_direct_model}()} +} +\concept{direct_model} diff --git a/man/as_latent_individual.Rd b/man/as_latent_individual.Rd index 21d48e429..ad4d10fa2 100644 --- a/man/as_latent_individual.Rd +++ b/man/as_latent_individual.Rd @@ -27,12 +27,14 @@ Other latent_individual: \code{\link{epidist_family_model.epidist_latent_individual}()}, \code{\link{epidist_formula_model.epidist_latent_individual}()}, \code{\link{epidist_validate.epidist_latent_individual}()}, +\code{\link{is_direct_model}()}, \code{\link{is_latent_individual}()} Other latent_individual: \code{\link{epidist_family_model.epidist_latent_individual}()}, \code{\link{epidist_formula_model.epidist_latent_individual}()}, \code{\link{epidist_validate.epidist_latent_individual}()}, +\code{\link{is_direct_model}()}, \code{\link{is_latent_individual}()} } \concept{latent_individual} diff --git a/man/epidist_family_model.epidist_latent_individual.Rd b/man/epidist_family_model.epidist_latent_individual.Rd index 007786633..52e2ea8ea 100644 --- a/man/epidist_family_model.epidist_latent_individual.Rd +++ b/man/epidist_family_model.epidist_latent_individual.Rd @@ -22,6 +22,7 @@ Other latent_individual: \code{\link{as_latent_individual}()}, \code{\link{epidist_formula_model.epidist_latent_individual}()}, \code{\link{epidist_validate.epidist_latent_individual}()}, +\code{\link{is_direct_model}()}, \code{\link{is_latent_individual}()} } \concept{latent_individual} diff --git a/man/epidist_formula_model.epidist_latent_individual.Rd b/man/epidist_formula_model.epidist_latent_individual.Rd index 885264cfc..60fa7251e 100644 --- a/man/epidist_formula_model.epidist_latent_individual.Rd +++ b/man/epidist_formula_model.epidist_latent_individual.Rd @@ -21,6 +21,7 @@ Other latent_individual: \code{\link{as_latent_individual}()}, \code{\link{epidist_family_model.epidist_latent_individual}()}, \code{\link{epidist_validate.epidist_latent_individual}()}, +\code{\link{is_direct_model}()}, \code{\link{is_latent_individual}()} } \concept{latent_individual} diff --git a/man/epidist_validate.epidist_direct_model.Rd b/man/epidist_validate.epidist_direct_model.Rd new file mode 100644 index 000000000..fb43f6b9b --- /dev/null +++ b/man/epidist_validate.epidist_direct_model.Rd @@ -0,0 +1,24 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/direct_model.R +\name{epidist_validate.epidist_direct_model} +\alias{epidist_validate.epidist_direct_model} +\title{Validate direct model data} +\usage{ +\method{epidist_validate}{epidist_direct_model}(data, ...) +} +\arguments{ +\item{data}{A \code{data.frame} containing line list data} + +\item{...}{...} +} +\description{ +This function checks whether the provided \code{data} object is suitable for +running the direct model. As well as making sure that +\code{is_direct_model()} is true, it also checks that \code{data} is a \code{data.frame} +with the correct columns. +} +\seealso{ +Other direct_model: +\code{\link{as_direct_model}()} +} +\concept{direct_model} diff --git a/man/epidist_validate.epidist_latent_individual.Rd b/man/epidist_validate.epidist_latent_individual.Rd index 986b13ded..0c0b33def 100644 --- a/man/epidist_validate.epidist_latent_individual.Rd +++ b/man/epidist_validate.epidist_latent_individual.Rd @@ -22,6 +22,7 @@ Other latent_individual: \code{\link{as_latent_individual}()}, \code{\link{epidist_family_model.epidist_latent_individual}()}, \code{\link{epidist_formula_model.epidist_latent_individual}()}, +\code{\link{is_direct_model}()}, \code{\link{is_latent_individual}()} } \concept{latent_individual} diff --git a/man/is_direct_model.Rd b/man/is_direct_model.Rd new file mode 100644 index 000000000..18ef5d24f --- /dev/null +++ b/man/is_direct_model.Rd @@ -0,0 +1,23 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/direct_model.R +\name{is_direct_model} +\alias{is_direct_model} +\title{Check if data has the \code{epidist_direct_model} class} +\usage{ +is_direct_model(data) +} +\arguments{ +\item{data}{A \code{data.frame} containing line list data} +} +\description{ +Check if data has the \code{epidist_direct_model} class +} +\seealso{ +Other latent_individual: +\code{\link{as_latent_individual}()}, +\code{\link{epidist_family_model.epidist_latent_individual}()}, +\code{\link{epidist_formula_model.epidist_latent_individual}()}, +\code{\link{epidist_validate.epidist_latent_individual}()}, +\code{\link{is_latent_individual}()} +} +\concept{latent_individual} diff --git a/man/is_latent_individual.Rd b/man/is_latent_individual.Rd index e42e351ea..204d510e9 100644 --- a/man/is_latent_individual.Rd +++ b/man/is_latent_individual.Rd @@ -17,6 +17,7 @@ Other latent_individual: \code{\link{as_latent_individual}()}, \code{\link{epidist_family_model.epidist_latent_individual}()}, \code{\link{epidist_formula_model.epidist_latent_individual}()}, -\code{\link{epidist_validate.epidist_latent_individual}()} +\code{\link{epidist_validate.epidist_latent_individual}()}, +\code{\link{is_direct_model}()} } \concept{latent_individual} diff --git a/tests/testthat/test-direct_model.R b/tests/testthat/test-direct_model.R new file mode 100644 index 000000000..99bd51e51 --- /dev/null +++ b/tests/testthat/test-direct_model.R @@ -0,0 +1,50 @@ +test_that("as_direct_model.data.frame with default settings an object with the correct classes", { # nolint: line_length_linter. + prep_obs <- as_direct_model(sim_obs) + expect_s3_class(prep_obs, "data.frame") + expect_s3_class(prep_obs, "epidist_direct_model") +}) + +test_that("as_direct_model.data.frame errors when passed incorrect inputs", { # nolint: line_length_linter. + expect_error(as_direct_model(list())) + expect_error(as_direct_model(sim_obs[, 1])) + expect_error({ + sim_obs$case <- paste("case_", seq_len(nrow(sim_obs))) + as_direct_model(sim_obs) + }) +}) + +# Make this data available for other tests +prep_obs <- as_direct_model(sim_obs) +family_lognormal <- epidist_family(prep_obs, family = brms::lognormal()) + +test_that("is_direct_model returns TRUE for correct input", { # nolint: line_length_linter. + expect_true(is_direct_model(prep_obs)) + expect_true({ + x <- list() + class(x) <- "epidist_direct_model" + is_direct_model(x) + }) +}) + +test_that("is_direct_model returns FALSE for incorrect input", { # nolint: line_length_linter. + expect_false(is_direct_model(list())) + expect_false({ + x <- list() + class(x) <- "epidist_direct_model_extension" + is_direct_model(x) + }) +}) + +test_that("epidist_validate.epidist_direct_model doesn't produce an error for correct input", { # nolint: line_length_linter. + expect_no_error(epidist_validate(prep_obs)) +}) + +test_that("epidist_validate.epidist_direct_model returns FALSE for incorrect input", { # nolint: line_length_linter. + expect_error(epidist_validate(list())) + expect_error(epidist_validate(prep_obs[, 1])) + expect_error({ + x <- list() + class(x) <- "epidist_direct_model" + epidist_validate(x) + }) +}) diff --git a/tests/testthat/test-int-direct_model.R b/tests/testthat/test-int-direct_model.R new file mode 100644 index 000000000..b366b938f --- /dev/null +++ b/tests/testthat/test-int-direct_model.R @@ -0,0 +1,36 @@ +# Note: some tests in this script are stochastic. As such, test failure may be +# bad luck rather than indicate an issue with the code. However, as these tests +# are reproducible, the distribution of test failures may be investigated by +# varying the input seed. Test failure at an unusually high rate does suggest +# a potential code issue. + +prep_obs <- as_direct_model(sim_obs) + +test_that("epidist.epidist_direct_model Stan code has no syntax errors and compiles in the default case", { # nolint: line_length_linter. + skip_on_cran() + stancode <- epidist( + data = prep_obs, + fn = brms::make_stancode, + output_dir = fs::dir_create(tempfile()) + ) + mod <- cmdstanr::cmdstan_model( + stan_file = cmdstanr::write_stan_file(stancode), compile = FALSE + ) + expect_true(mod$check_syntax()) + expect_no_error(mod$compile()) +}) + +test_that("epidist.epidist_direct_model fits and the MCMC converges in the default case", { # nolint: line_length_linter. + # Note: this test is stochastic. See note at the top of this script + skip_on_cran() + set.seed(1) + fit <- epidist( + data = prep_obs, + seed = 1, + silent = 2, + output_dir = fs::dir_create(tempfile()) + ) + expect_s3_class(fit, "brmsfit") + expect_s3_class(fit, "epidist_fit") + expect_convergence(fit) +})