Skip to content

Commit

Permalink
add marginal model integration tests
Browse files Browse the repository at this point in the history
  • Loading branch information
seabbs committed Dec 2, 2024
1 parent 4d98a7b commit 7c97226
Show file tree
Hide file tree
Showing 4 changed files with 139 additions and 15 deletions.
5 changes: 3 additions & 2 deletions R/marginal_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -167,9 +167,10 @@ epidist_transform_data_model.epidist_marginal_model <- function(
if (n_rows_before > n_rows_after) {
cli::cli_inform("Data summarised by unique combinations of:")

if (length(all.vars(formula[[3]])) > 0) {
formula_vars <- setdiff(names(trans_data), c(required_cols, "n"))
if (length(formula_vars) > 0) {
cli::cli_inform(
paste0("* Formula variables: {.code {all.vars(formula[[3]])}}")
paste0("* Formula variables: {.code {formula_vars}}")
)
}

Expand Down
22 changes: 21 additions & 1 deletion tests/testthat/setup.R
Original file line number Diff line number Diff line change
Expand Up @@ -113,13 +113,16 @@ sim_obs_sex <- as_epidist_linelist_data(
sim_obs_sex$obs_time,
sex = sim_obs_sex$sex
)

prep_obs <- as_epidist_latent_model(sim_obs)
prep_naive_obs <- as_epidist_naive_model(sim_obs)
prep_marginal_obs <- as_epidist_marginal_model(sim_obs)
prep_obs_gamma <- as_epidist_latent_model(sim_obs_gamma)
prep_obs_sex <- as_epidist_latent_model(sim_obs_sex)

prep_marginal_obs <- as_epidist_marginal_model(sim_obs)
prep_marginal_obs_gamma <- as_epidist_marginal_model(sim_obs_gamma)
prep_marginal_obs_sex <- as_epidist_marginal_model(sim_obs_sex)

if (not_on_cran()) {
set.seed(1)
fit <- epidist(
Expand All @@ -130,17 +133,34 @@ if (not_on_cran()) {
fit_rstan <- epidist(
data = prep_obs, seed = 1, chains = 2, cores = 2, silent = 2, refresh = 0
)
fit_marginal <- suppressMessages(epidist(
data = prep_marginal_obs, seed = 1, chains = 2, cores = 2, silent = 2,
refresh = 0, backend = "cmdstanr"
))

fit_gamma <- epidist(
data = prep_obs_gamma, family = Gamma(link = "log"),
seed = 1, chains = 2, cores = 2, silent = 2, refresh = 0,
backend = "cmdstanr"
)

fit_marginal_gamma <- suppressMessages(epidist(
data = prep_marginal_obs_gamma, family = Gamma(link = "log"),
seed = 1, chains = 2, cores = 2, silent = 2, refresh = 0,
backend = "cmdstanr"
))

fit_sex <- epidist(
data = prep_obs_sex,
formula = bf(mu ~ 1 + sex, sigma ~ 1 + sex),
seed = 1, silent = 2, refresh = 0,
cores = 2, chains = 2, backend = "cmdstanr"
)

fit_marginal_sex <- suppressMessages(epidist(
data = prep_marginal_obs_sex,
formula = bf(mu ~ 1 + sex, sigma ~ 1 + sex),
seed = 1, silent = 2, refresh = 50,
cores = 2, chains = 2, backend = "cmdstanr"
))
}
72 changes: 60 additions & 12 deletions tests/testthat/test-int-marginal_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,29 +6,77 @@

test_that("epidist.epidist_marginal_model Stan code has no syntax errors in the default case", { # nolint: line_length_linter.
skip_on_cran()
stancode <- epidist(
stancode <- suppressMessages(epidist(
data = prep_marginal_obs,
fn = brms::make_stancode
)
))
mod <- cmdstanr::cmdstan_model(
stan_file = cmdstanr::write_stan_file(stancode), compile = FALSE
)
expect_true(mod$check_syntax())
})

test_that("epidist.epidist_marginal_model fits and the MCMC converges in the default case", { # nolint: line_length_linter.
# Note: this test is stochastic. See note at the top of this script
skip_on_cran()
expect_s3_class(fit_marginal, "brmsfit")
expect_s3_class(fit_marginal, "epidist_fit")
expect_convergence(fit_marginal)
})

test_that("epidist.epidist_marginal_model recovers the simulation settings for the delay distribution in the default case", { # nolint: line_length_linter.
# Note: this test is stochastic. See note at the top of this script
skip_on_cran()
set.seed(1)
fit <- epidist(
data = prep_marginal_obs,
seed = 1,
silent = 2, refresh = 0,
cores = 2,
chains = 2,
backend = "cmdstanr"
pred <- predict_delay_parameters(fit_marginal)
expect_equal(mean(pred$mu), meanlog, tolerance = 0.1)
expect_equal(mean(pred$sigma), sdlog, tolerance = 0.1)
})

test_that("epidist.epidist_marginal_model fits and the MCMC converges in the gamma delay case", { # nolint: line_length_linter.
# Note: this test is stochastic. See note at the top of this script
skip_on_cran()
set.seed(1)
expect_s3_class(fit_marginal_gamma, "brmsfit")
expect_s3_class(fit_marginal_gamma, "epidist_fit")
expect_convergence(fit_marginal_gamma)
})

test_that("epidist.epidist_marginal_model recovers the simulation settings for the delay distribution in the gamma delay case", { # nolint: line_length_linter.
# Note: this test is stochastic. See note at the top of this script
skip_on_cran()
set.seed(1)
draws_gamma <- posterior::as_draws_df(fit_marginal_gamma$fit)
draws_gamma_mu <- exp(draws_gamma$Intercept)
draws_gamma_shape <- exp(draws_gamma$Intercept_shape)
draws_gamma_mu_ecdf <- ecdf(draws_gamma_mu)
draws_gamma_shape_ecdf <- ecdf(draws_gamma_shape)
quantile_mu <- draws_gamma_mu_ecdf(mu)
quantile_shape <- draws_gamma_shape_ecdf(shape)
expect_gte(quantile_mu, 0.025)
expect_lte(quantile_mu, 0.975)
expect_gte(quantile_shape, 0.025)
expect_lte(quantile_shape, 0.975)
})

test_that("epidist.epidist_marginal_model fits and recovers a sex effect", { # nolint: line_length_linter.
# Note: this test is stochastic. See note at the top of this script
skip_on_cran()
set.seed(1)
expect_s3_class(fit_marginal_sex, "brmsfit")
expect_s3_class(fit_marginal_sex, "epidist_fit")
expect_convergence(fit_marginal_sex)

draws <- posterior::as_draws_df(fit_marginal_sex$fit)
expect_equal(mean(draws$b_Intercept), meanlog_m, tolerance = 0.3)
expect_equal(
mean(draws$b_Intercept + draws$b_sex), meanlog_f,
tolerance = 0.3
)
expect_equal(mean(exp(draws$b_sigma_Intercept)), sdlog_m, tolerance = 0.3)
expect_equal(
mean(exp(draws$b_sigma_Intercept + draws$b_sigma_sex)),
sdlog_f,
tolerance = 0.3
)
expect_s3_class(fit, "brmsfit")
expect_s3_class(fit, "epidist_fit")
expect_convergence(fit)
})
55 changes: 55 additions & 0 deletions tests/testthat/test-utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,58 @@ test_that(".make_intercepts_explicit does not add an intercept if the distributi
expect_identical(formula$pforms$mu, formula_updated$pforms$mu)
expect_identical(formula$pforms$sigma, formula_updated$pforms$sigma)
})

test_that(
".summarise_n_by_formula correctly summarizes counts by grouping variables",
{
df <- tibble::tibble(
x = c(1, 1, 2, 2),
y = c("a", "b", "a", "b"),
n = c(2, 3, 4, 1)
)

# Test grouping by single variable
result <- .summarise_n_by_formula(df, by = "x")
expect_identical(result$x, c(1, 2))
expect_identical(result$n, c(5, 5))

# Test grouping by multiple variable
result <- .summarise_n_by_formula(df, by = c("x", "y"))
expect_identical(result$x, c(1, 1, 2, 2))
expect_identical(result$y, c("a", "b", "a", "b"))
expect_identical(result$n, c(2, 3, 4, 1))

# Test with formula
formula <- bf(mu ~ x + y)
result <- .summarise_n_by_formula(df, formula = formula)
expect_identical(result$x, c(1, 1, 2, 2))
expect_identical(result$y, c("a", "b", "a", "b"))
expect_identical(result$n, c(2, 3, 4, 1))

# Test with both by and formula
formula <- bf(mu ~ y)
result <- .summarise_n_by_formula(df, by = "x", formula = formula)
expect_identical(result$x, c(1, 1, 2, 2))
expect_identical(result$y, c("a", "b", "a", "b"))
expect_identical(result$n, c(2, 3, 4, 1))
}
)

test_that(
".summarise_n_by_formula handles missing grouping variables appropriately",
{
df <- data.frame(x = 1:2, n = c(1, 2))
expect_error(
.summarise_n_by_formula(df, by = "missing"),
"object 'missing' not found"
)
}
)

test_that(".summarise_n_by_formula requires n column in data", {
df <- data.frame(x = 1:2)
expect_error(
.summarise_n_by_formula(df, by = "x"),
"Column `n` not found in `.data`."
)
})

0 comments on commit 7c97226

Please sign in to comment.