From 7c972261fa458a5da38b6cd8ab9b33a572211858 Mon Sep 17 00:00:00 2001 From: Sam Date: Mon, 2 Dec 2024 15:22:16 +0000 Subject: [PATCH] add marginal model integration tests --- R/marginal_model.R | 5 +- tests/testthat/setup.R | 22 +++++++- tests/testthat/test-int-marginal_model.R | 72 ++++++++++++++++++++---- tests/testthat/test-utils.R | 55 ++++++++++++++++++ 4 files changed, 139 insertions(+), 15 deletions(-) diff --git a/R/marginal_model.R b/R/marginal_model.R index a75278be8..7723140b6 100644 --- a/R/marginal_model.R +++ b/R/marginal_model.R @@ -167,9 +167,10 @@ epidist_transform_data_model.epidist_marginal_model <- function( if (n_rows_before > n_rows_after) { cli::cli_inform("Data summarised by unique combinations of:") - if (length(all.vars(formula[[3]])) > 0) { + formula_vars <- setdiff(names(trans_data), c(required_cols, "n")) + if (length(formula_vars) > 0) { cli::cli_inform( - paste0("* Formula variables: {.code {all.vars(formula[[3]])}}") + paste0("* Formula variables: {.code {formula_vars}}") ) } diff --git a/tests/testthat/setup.R b/tests/testthat/setup.R index 03271e4cf..f9d76a014 100644 --- a/tests/testthat/setup.R +++ b/tests/testthat/setup.R @@ -113,13 +113,16 @@ sim_obs_sex <- as_epidist_linelist_data( sim_obs_sex$obs_time, sex = sim_obs_sex$sex ) - prep_obs <- as_epidist_latent_model(sim_obs) prep_naive_obs <- as_epidist_naive_model(sim_obs) prep_marginal_obs <- as_epidist_marginal_model(sim_obs) prep_obs_gamma <- as_epidist_latent_model(sim_obs_gamma) prep_obs_sex <- as_epidist_latent_model(sim_obs_sex) +prep_marginal_obs <- as_epidist_marginal_model(sim_obs) +prep_marginal_obs_gamma <- as_epidist_marginal_model(sim_obs_gamma) +prep_marginal_obs_sex <- as_epidist_marginal_model(sim_obs_sex) + if (not_on_cran()) { set.seed(1) fit <- epidist( @@ -130,6 +133,10 @@ if (not_on_cran()) { fit_rstan <- epidist( data = prep_obs, seed = 1, chains = 2, cores = 2, silent = 2, refresh = 0 ) + fit_marginal <- suppressMessages(epidist( + data = prep_marginal_obs, seed = 1, chains = 2, cores = 2, silent = 2, + refresh = 0, backend = "cmdstanr" + )) fit_gamma <- epidist( data = prep_obs_gamma, family = Gamma(link = "log"), @@ -137,10 +144,23 @@ if (not_on_cran()) { backend = "cmdstanr" ) + fit_marginal_gamma <- suppressMessages(epidist( + data = prep_marginal_obs_gamma, family = Gamma(link = "log"), + seed = 1, chains = 2, cores = 2, silent = 2, refresh = 0, + backend = "cmdstanr" + )) + fit_sex <- epidist( data = prep_obs_sex, formula = bf(mu ~ 1 + sex, sigma ~ 1 + sex), seed = 1, silent = 2, refresh = 0, cores = 2, chains = 2, backend = "cmdstanr" ) + + fit_marginal_sex <- suppressMessages(epidist( + data = prep_marginal_obs_sex, + formula = bf(mu ~ 1 + sex, sigma ~ 1 + sex), + seed = 1, silent = 2, refresh = 50, + cores = 2, chains = 2, backend = "cmdstanr" + )) } diff --git a/tests/testthat/test-int-marginal_model.R b/tests/testthat/test-int-marginal_model.R index 97099b747..d44131b90 100644 --- a/tests/testthat/test-int-marginal_model.R +++ b/tests/testthat/test-int-marginal_model.R @@ -6,10 +6,10 @@ test_that("epidist.epidist_marginal_model Stan code has no syntax errors in the default case", { # nolint: line_length_linter. skip_on_cran() - stancode <- epidist( + stancode <- suppressMessages(epidist( data = prep_marginal_obs, fn = brms::make_stancode - ) + )) mod <- cmdstanr::cmdstan_model( stan_file = cmdstanr::write_stan_file(stancode), compile = FALSE ) @@ -17,18 +17,66 @@ test_that("epidist.epidist_marginal_model Stan code has no syntax errors in the }) test_that("epidist.epidist_marginal_model fits and the MCMC converges in the default case", { # nolint: line_length_linter. + # Note: this test is stochastic. See note at the top of this script + skip_on_cran() + expect_s3_class(fit_marginal, "brmsfit") + expect_s3_class(fit_marginal, "epidist_fit") + expect_convergence(fit_marginal) +}) + +test_that("epidist.epidist_marginal_model recovers the simulation settings for the delay distribution in the default case", { # nolint: line_length_linter. # Note: this test is stochastic. See note at the top of this script skip_on_cran() set.seed(1) - fit <- epidist( - data = prep_marginal_obs, - seed = 1, - silent = 2, refresh = 0, - cores = 2, - chains = 2, - backend = "cmdstanr" + pred <- predict_delay_parameters(fit_marginal) + expect_equal(mean(pred$mu), meanlog, tolerance = 0.1) + expect_equal(mean(pred$sigma), sdlog, tolerance = 0.1) +}) + +test_that("epidist.epidist_marginal_model fits and the MCMC converges in the gamma delay case", { # nolint: line_length_linter. + # Note: this test is stochastic. See note at the top of this script + skip_on_cran() + set.seed(1) + expect_s3_class(fit_marginal_gamma, "brmsfit") + expect_s3_class(fit_marginal_gamma, "epidist_fit") + expect_convergence(fit_marginal_gamma) +}) + +test_that("epidist.epidist_marginal_model recovers the simulation settings for the delay distribution in the gamma delay case", { # nolint: line_length_linter. + # Note: this test is stochastic. See note at the top of this script + skip_on_cran() + set.seed(1) + draws_gamma <- posterior::as_draws_df(fit_marginal_gamma$fit) + draws_gamma_mu <- exp(draws_gamma$Intercept) + draws_gamma_shape <- exp(draws_gamma$Intercept_shape) + draws_gamma_mu_ecdf <- ecdf(draws_gamma_mu) + draws_gamma_shape_ecdf <- ecdf(draws_gamma_shape) + quantile_mu <- draws_gamma_mu_ecdf(mu) + quantile_shape <- draws_gamma_shape_ecdf(shape) + expect_gte(quantile_mu, 0.025) + expect_lte(quantile_mu, 0.975) + expect_gte(quantile_shape, 0.025) + expect_lte(quantile_shape, 0.975) +}) + +test_that("epidist.epidist_marginal_model fits and recovers a sex effect", { # nolint: line_length_linter. + # Note: this test is stochastic. See note at the top of this script + skip_on_cran() + set.seed(1) + expect_s3_class(fit_marginal_sex, "brmsfit") + expect_s3_class(fit_marginal_sex, "epidist_fit") + expect_convergence(fit_marginal_sex) + + draws <- posterior::as_draws_df(fit_marginal_sex$fit) + expect_equal(mean(draws$b_Intercept), meanlog_m, tolerance = 0.3) + expect_equal( + mean(draws$b_Intercept + draws$b_sex), meanlog_f, + tolerance = 0.3 + ) + expect_equal(mean(exp(draws$b_sigma_Intercept)), sdlog_m, tolerance = 0.3) + expect_equal( + mean(exp(draws$b_sigma_Intercept + draws$b_sigma_sex)), + sdlog_f, + tolerance = 0.3 ) - expect_s3_class(fit, "brmsfit") - expect_s3_class(fit, "epidist_fit") - expect_convergence(fit) }) diff --git a/tests/testthat/test-utils.R b/tests/testthat/test-utils.R index 1fb3add13..e136db754 100644 --- a/tests/testthat/test-utils.R +++ b/tests/testthat/test-utils.R @@ -98,3 +98,58 @@ test_that(".make_intercepts_explicit does not add an intercept if the distributi expect_identical(formula$pforms$mu, formula_updated$pforms$mu) expect_identical(formula$pforms$sigma, formula_updated$pforms$sigma) }) + +test_that( + ".summarise_n_by_formula correctly summarizes counts by grouping variables", + { + df <- tibble::tibble( + x = c(1, 1, 2, 2), + y = c("a", "b", "a", "b"), + n = c(2, 3, 4, 1) + ) + + # Test grouping by single variable + result <- .summarise_n_by_formula(df, by = "x") + expect_identical(result$x, c(1, 2)) + expect_identical(result$n, c(5, 5)) + + # Test grouping by multiple variable + result <- .summarise_n_by_formula(df, by = c("x", "y")) + expect_identical(result$x, c(1, 1, 2, 2)) + expect_identical(result$y, c("a", "b", "a", "b")) + expect_identical(result$n, c(2, 3, 4, 1)) + + # Test with formula + formula <- bf(mu ~ x + y) + result <- .summarise_n_by_formula(df, formula = formula) + expect_identical(result$x, c(1, 1, 2, 2)) + expect_identical(result$y, c("a", "b", "a", "b")) + expect_identical(result$n, c(2, 3, 4, 1)) + + # Test with both by and formula + formula <- bf(mu ~ y) + result <- .summarise_n_by_formula(df, by = "x", formula = formula) + expect_identical(result$x, c(1, 1, 2, 2)) + expect_identical(result$y, c("a", "b", "a", "b")) + expect_identical(result$n, c(2, 3, 4, 1)) + } +) + +test_that( + ".summarise_n_by_formula handles missing grouping variables appropriately", + { + df <- data.frame(x = 1:2, n = c(1, 2)) + expect_error( + .summarise_n_by_formula(df, by = "missing"), + "object 'missing' not found" + ) + } +) + +test_that(".summarise_n_by_formula requires n column in data", { + df <- data.frame(x = 1:2) + expect_error( + .summarise_n_by_formula(df, by = "x"), + "Column `n` not found in `.data`." + ) +})