Skip to content

Commit

Permalink
First draft on moving marginal_model into functions
Browse files Browse the repository at this point in the history
  • Loading branch information
athowes committed Nov 22, 2024
1 parent e14e880 commit bd6c7be
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 1 deletion.
63 changes: 63 additions & 0 deletions R/marginal_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -59,3 +59,66 @@ assert_epidist.epidist_marginal_model <- function(data, ...) {
is_epidist_marginal_model <- function(data) {
inherits(data, "epidist_marginal_model")
}

#' Create the model-specific component of an `epidist` custom family
#'
#' @inheritParams epidist_family_model
#' @param ... Additional arguments passed to method.
#' @method epidist_family_model epidist_marginal_model
#' @family marginal_model
#' @export
epidist_family_model.epidist_marginal_model <- function(
data, family, ...) {
custom_family <- brms::custom_family(
"primarycensored_wrapper",
dpars = family$dpars,
links = c(family$link, family$other_links),
lb = c(NA, as.numeric(lapply(family$other_bounds, "[[", "lb"))),
ub = c(NA, as.numeric(lapply(family$other_bounds, "[[", "ub"))),
type = "int",
loop = TRUE,
vars = "vreal1[n]"
)
return(custom_family)
}

#' Define the model-specific component of an `epidist` custom formula
#'
#' @inheritParams epidist_formula_model
#' @param ... Additional arguments passed to method.
#' @method epidist_formula_model epidist_marginal_model
#' @family marginal_model
#' @export
epidist_formula_model.epidist_marginal_model <- function(
data, formula, ...) {
# data is only used to dispatch on
formula <- stats::update(
formula, delay | weights(n) + vreal(pwindow) ~ .
)
return(formula)
}

#' @method epidist_stancode epidist_marginal_model
#' @importFrom brms stanvar
#' @family marginal_model
#' @autoglobal
#' @export
epidist_stancode.epidist_marginal_model <- function(data, ...) {
assert_epidist(data)

stanvars_version <- .version_stanvar()

stanvars_functions <- brms::stanvar(
block = "functions",
scode = .stan_chunk(file.path("marginal_model", "functions.stan"))
)

pcd_stanvars_functions <- brms::stanvar(
block = "functions",
scode = pcd_load_stan_functions()
)

stanvars_all <- stanvars_version + stanvars_functions + pcd_stanvars_functions

return(stanvars_all)
}
42 changes: 41 additions & 1 deletion tests/testthat/test-marginal_model.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,45 @@
test_that("as_epidist_marginal_model.epidist_linelist_data with default settings an object with the correct classes", { # nolint: line_length_linter.
prep_obs <- as_epidist_marginal_model(sim_obs)
expect_s3_class(prep_obs, "data.frame")
expect_s3_class(prep_obs, "epidist_latent_model")
expect_s3_class(prep_obs, "epidist_marginal_model")
})

test_that("as_epidist_marginal_model.epidist_linelist_data when passed incorrect inputs", { # nolint: line_length_linter.
expect_error(as_epidist_marginal_model(list()))
expect_error(as_epidist_marginal_model(sim_obs[, 1]))
})

# Make this data available for other tests
family_lognormal <- epidist_family(prep_obs, family = brms::lognormal())

test_that("is_epidist_marginal_model returns TRUE for correct input", { # nolint: line_length_linter.
expect_true(is_epidist_marginal_model(prep_obs))
expect_true({
x <- list()
class(x) <- "epidist_marginal_model"
is_epidist_marginal_model(x)
})
})

test_that("is_epidist_marginal_model returns FALSE for incorrect input", { # nolint: line_length_linter.
expect_false(is_epidist_marginal_model(list()))
expect_false({
x <- list()
class(x) <- "epidist_marginal_model_extension"
is_epidist_marginal_model(x)
})
})

test_that("assert_epidist.epidist_marginal_model doesn't produce an error for correct input", { # nolint: line_length_linter.
expect_no_error(assert_epidist(prep_obs))
})

test_that("assert_epidist.epidist_marginal_model returns FALSE for incorrect input", { # nolint: line_length_linter.
expect_error(assert_epidist(list()))
expect_error(assert_epidist(prep_obs[, 1]))
expect_error({
x <- list()
class(x) <- "epidist_marginal_model"
assert_epidist(x)
})
})

0 comments on commit bd6c7be

Please sign in to comment.