-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
epidist_diagnostics
(#175)
* Remove sample_model * Start to write epidist_diagnostics * Rebase merged PRs * Don't need this to be S3 * Finish removal of S3 and improve tests * Replace custom code with diagnostics function * Lint and documentation fixes * Add documentation to epidist_diagnostics * Lint and add diagnostics as section * Add more unit tests for diagnostics * expect_numeric doesn't exist * Make tests softer
- Loading branch information
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,59 @@ | ||
#' Diagnostics for `epidist_fit` models | ||
#' | ||
#' This function computes diagnostics to assess the quality of a fitted model. | ||
#' When the fitting algorithm used is `"sampling"` (HMC) then the output of | ||
#' `epidist_diagnostics` is a `data.table` containing: | ||
#' | ||
#' * `time`: the total time taken to fit all chains | ||
#' * `samples`: the total number of samples across all chains | ||
#' * `max_rhat`: the highest value of the Gelman-Rubin statistic | ||
#' * `divergent_transitions`: the total number of divergent transitions | ||
#' * `per_divergent_transitions`: the proportion of samples which had divergent | ||
#' transitions | ||
#' * `max_treedepth`: the highest value of the treedepth HMC parameter | ||
#' * `no_at_max_treedepth`: the number of samples which attained the | ||
#' `max_treedepth` | ||
#' * `per_at_max_treedepth`: the proportion of samples which attained the | ||
#' `max_treedepth` | ||
#' | ||
#' When the fitting algorithm is not `"sampling"` (see `?brms::brm` for other | ||
#' possible algorithms) then diagnostics are yet to be implemented. | ||
#' @param fit A fitted model of class `epidist_fit` | ||
#' @family diagnostics | ||
#' @autoglobal | ||
#' @export | ||
epidist_diagnostics <- function(fit) { | ||
if (!inherits(fit, "epidist_fit")) { | ||
cli::cli_abort(c( | ||
"!" = "Diagnostics only supported for objects of class epidist_fit" | ||
)) | ||
} | ||
if (fit$algorithm %in% c("laplace", "meanfield", "fullrank", "pathfinder")) { | ||
cli::cli_abort(c( | ||
"!" = paste0( | ||
"Diagnostics not yet supported for the algorithm: ", fit$algorithm | ||
) | ||
)) | ||
} | ||
if (fit$algorithm == "sampling") { | ||
np <- brms::nuts_params(fit) | ||
divergent_indices <- np$Parameter == "divergent__" | ||
treedepth_indices <- np$Parameter == "treedepth__" | ||
diagnostics <- data.table( | ||
"time" = sum(rstan::get_elapsed_time(fit$fit)), | ||
"samples" = nrow(np) / length(unique(np$Parameter)), | ||
"max_rhat" = round(max(brms::rhat(fit)), 3), | ||
"divergent_transitions" = sum(np[divergent_indices, ]$Value), | ||
"per_divergent_transitions" = mean(np[divergent_indices, ]$Value), | ||
"max_treedepth" = max(np[treedepth_indices, ]$Value) | ||
) | ||
diagnostics[, no_at_max_treedepth := | ||
sum(np[treedepth_indices, ]$Value == max_treedepth)] | ||
diagnostics[, per_at_max_treedepth := no_at_max_treedepth / samples] | ||
} else { | ||
cli::cli_abort(c( | ||
"!" = paste0("Unrecognised algorithm: ", fit$algorithm) | ||
)) | ||
} | ||
return(diagnostics) | ||
} |
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
This file was deleted.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,12 +1,7 @@ | ||
expect_convergence <- function(fit, per_dts = 0.05, treedepth = 10, | ||
rhat = 1.05) { | ||
np <- brms::nuts_params(fit) | ||
divergent_indices <- np$Parameter == "divergent__" | ||
per_divergent_transitions <- mean(np[divergent_indices, ]$Value) | ||
treedepth_indices <- np$Parameter == "treedepth__" | ||
max_treedepth <- max(np[treedepth_indices, ]$Value) | ||
max_rhat <- max(brms::rhat(fit)) | ||
testthat::expect_lt(per_divergent_transitions, per_dts) | ||
testthat::expect_lt(max_treedepth, treedepth) | ||
testthat::expect_lt(max_rhat, rhat) | ||
diag <- epidist_diagnostics(fit) | ||
testthat::expect_lt(diag$per_divergent_transitions, per_dts) | ||
testthat::expect_lt(diag$max_treedepth, treedepth) | ||
testthat::expect_lt(diag$max_rhat, rhat) | ||
} |