Skip to content

Commit

Permalink
add tests for sensitivity analysis
Browse files Browse the repository at this point in the history
  • Loading branch information
krisrs1128 committed Sep 8, 2024
1 parent bff71d3 commit 88c92e8
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 2 deletions.
4 changes: 2 additions & 2 deletions R/sensitivity.R
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ covariance_matrix <- function(model, confound_ix = NULL, rho = 0.0) {
#' outcome_estimator = glmnet_model(lambda = 1e-2)
#' ) |>
#' estimate(exper)
#' rho_seq <- c(-0.2, 0.2)
#' nu_seq <- c(-0.2, 0.2)
#' perturb <- matrix(
#' c(
#' 0, 3, 0,
Expand All @@ -419,7 +419,7 @@ covariance_matrix <- function(model, confound_ix = NULL, rho = 0.0) {
#' ),
#' nrow = 3, byrow = TRUE
#' )
#' sensitivity_perturb(model, exper, perturb, n_bootstrap = 2)
#' sensitivity_perturb(model, exper, perturb, nu_seq, n_bootstrap = 2)
#' @export
sensitivity_perturb <- function(
model, exper, perturb, nu_seq = NULL, n_bootstrap = 100, progress = TRUE) {
Expand Down
42 changes: 42 additions & 0 deletions tests/testthat/test-sensitivity.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,46 @@ test_that("Raises error on inappropriate model input.", {

model <- multimedia(exper, brms_model())
expect_error(check_supported(model))
})

pathwise <- sensitivity_pathwise(model, exper, subset_indices, rho_seq, n_bootstrap = 2)

test_that("Pathwise sensitivity curve has correct column names.", {
expect_named(
pathwise,
c("rho", "outcome", "direct_setting", "contrast", "mediator",
"indirect_effect", "indirect_effect_standard_error")
)
})

test_that("Pathwise sensitivity covers all mediators", {
expect_equal(nrow(pathwise), 4 * length(rho_seq))
expect_true(all(mediators(model) %in% pathwise$mediator))
expect_true(all(outcomes(model) %in% pathwise$outcome))
})

perturb <- matrix(
c(
0, 3, 0,
3, 0, 0,
0, 0, 0
),
nrow = 3, byrow = TRUE
)
nu_seq <- c(-0.2, 0, 0.2)
curve <- sensitivity_perturb(model, exper, perturb, nu_seq, n_bootstrap = 2)

test_that("Perturbation sensitivity has the correct column names", {
expect_named(
curve,
c("nu", colnames(overall), "indirect_effect_standard_error")
)
})

test_that("All perturbation values appear in the sensitivity curve.", {
expect_equal(unique(curve$nu), nu_seq)
})

test_that("All outcomes appear in the sensitivity curve.", {
expect_equal(unique(curve$outcome), outcomes(model))
})

0 comments on commit 88c92e8

Please sign in to comment.