diff --git a/tests/testthat/test-gen.R b/tests/testthat/test-gen.R index 3e6a246b5..f51162663 100644 --- a/tests/testthat/test-gen.R +++ b/tests/testthat/test-gen.R @@ -104,13 +104,10 @@ test_that("epidist_gen_posterior_predict returns a function that predicts delays test_that("epidist_gen_posterior_epred returns a function that creates arrays with correct dimensions", { # nolint: line_length_linter. skip_on_cran() # Test lognormal - prep <- brms::prepare_predictions(fit) - epred_fn <- epidist_gen_posterior_epred(lognormal()) - epred <- epred_fn(prep) - expect_setequal(class(epred), c("matrix", "array")) - expect_identical(nrow(epred), prep$ndraws) - expect_identical(ncol(epred), length(prep$data$Y)) - expect_gte(min(epred), 0) + epred <- prep_obs |> + tidybayes::add_epred_draws(fit) + expect_equal(mean(epred$.epred), 5.97, tolerance = 0.1) + expect_gte(min(epred$.epred), 0) # Test gamma prep_gamma <- brms::prepare_predictions(fit_gamma) diff --git a/tests/testthat/test-latent_model.R b/tests/testthat/test-latent_model.R index ca8228380..c1a64560f 100644 --- a/tests/testthat/test-latent_model.R +++ b/tests/testthat/test-latent_model.R @@ -62,8 +62,7 @@ test_that("epidist_gen_log_lik_latent returns a function that produces valid log # Test lognormal prep <- brms::prepare_predictions(fit) i <- 1 - family <- epidist_family(data = prep_obs, family = lognormal()) - log_lik_fn <- epidist_gen_log_lik_latent(family) + log_lik_fn <- epidist_gen_log_lik_latent(lognormal()) log_lik <- log_lik_fn(i = i, prep) expect_length(log_lik, prep$ndraws) expect_false(anyNA(log_lik)) @@ -71,8 +70,7 @@ test_that("epidist_gen_log_lik_latent returns a function that produces valid log # Test gamma prep_gamma <- brms::prepare_predictions(fit_gamma) - family_gamma <- epidist_family(data = prep_obs, family = Gamma()) - log_lik_fn_gamma <- epidist_gen_log_lik_latent(family_gamma) + log_lik_fn_gamma <- epidist_gen_log_lik_latent(Gamma()) log_lik_gamma <- log_lik_fn_gamma(i = i, prep_gamma) expect_length(log_lik_gamma, prep_gamma$ndraws) expect_false(anyNA(log_lik_gamma)) diff --git a/vignettes/ebola.Rmd b/vignettes/ebola.Rmd index be5807520..f66f1a5b0 100644 --- a/vignettes/ebola.Rmd +++ b/vignettes/ebola.Rmd @@ -320,7 +320,6 @@ Figure \@ref(fig:epred)B illustrates the higher mean of men as compared with wom ```{r} epred_draws <- obs_prep |> - as.data.frame() |> data_grid(NA) |> mutate(relative_obs_time = NA, pwindow = NA, swindow = NA) |> add_epred_draws(fit, dpar = TRUE) @@ -332,7 +331,6 @@ epred_base_figure <- epred_draws |> theme_minimal() epred_draws_sex <- obs_prep |> - as.data.frame() |> data_grid(sex) |> mutate(relative_obs_time = NA, pwindow = NA, swindow = NA) |> add_epred_draws(fit_sex, dpar = TRUE) @@ -344,7 +342,6 @@ epred_sex_figure <- epred_draws_sex |> theme_minimal() epred_draws_sex_district <- obs_prep |> - as.data.frame() |> data_grid(sex, district) |> mutate(relative_obs_time = NA, pwindow = NA, swindow = NA) |> add_epred_draws(fit_sex_district, dpar = TRUE)