Skip to content

Commit

Permalink
Issue 163: Create epidist_diagnostics (#175)
Browse files Browse the repository at this point in the history
* Remove sample_model

* Start to write epidist_diagnostics

* Rebase merged PRs

* Don't need this to be S3

* Finish removal of S3 and improve tests

* Replace custom code with diagnostics function

* Lint and documentation fixes

* Add documentation to epidist_diagnostics

* Lint and add diagnostics as section

* Add more unit tests for diagnostics

* expect_numeric doesn't exist

* Make tests softer
  • Loading branch information
athowes authored Jul 19, 2024
1 parent 7584099 commit 55ab8bf
Show file tree
Hide file tree
Showing 15 changed files with 140 additions and 138 deletions.
4 changes: 3 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ export(construct_cases_by_obs_window)
export(draws_to_long)
export(drop_zero)
export(epidist)
export(epidist_diagnostics)
export(epidist_family)
export(epidist_formula)
export(epidist_prior)
Expand All @@ -47,7 +48,6 @@ export(plot_mean_posterior_pred)
export(plot_recovery)
export(plot_relative_recovery)
export(reverse_obs_at)
export(sample_model)
export(simulate_double_censored_pmf)
export(simulate_exponential_cases)
export(simulate_gillespie)
Expand All @@ -63,6 +63,8 @@ importFrom(checkmate,assert_data_frame)
importFrom(checkmate,assert_int)
importFrom(checkmate,assert_names)
importFrom(checkmate,assert_numeric)
importFrom(cli,cli_abort)
importFrom(cli,cli_inform)
importFrom(posterior,as_draws_df)
importFrom(rstan,lookup)
importFrom(stats,as.formula)
Expand Down
59 changes: 59 additions & 0 deletions R/diagnostics.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
#' Diagnostics for `epidist_fit` models
#'
#' This function computes diagnostics to assess the quality of a fitted model.
#' When the fitting algorithm used is `"sampling"` (HMC) then the output of
#' `epidist_diagnostics` is a `data.table` containing:
#'
#' * `time`: the total time taken to fit all chains
#' * `samples`: the total number of samples across all chains
#' * `max_rhat`: the highest value of the Gelman-Rubin statistic
#' * `divergent_transitions`: the total number of divergent transitions
#' * `per_divergent_transitions`: the proportion of samples which had divergent
#' transitions
#' * `max_treedepth`: the highest value of the treedepth HMC parameter
#' * `no_at_max_treedepth`: the number of samples which attained the
#' `max_treedepth`
#' * `per_at_max_treedepth`: the proportion of samples which attained the
#' `max_treedepth`
#'
#' When the fitting algorithm is not `"sampling"` (see `?brms::brm` for other
#' possible algorithms) then diagnostics are yet to be implemented.
#' @param fit A fitted model of class `epidist_fit`
#' @family diagnostics
#' @autoglobal
#' @export
epidist_diagnostics <- function(fit) {
if (!inherits(fit, "epidist_fit")) {
cli::cli_abort(c(
"!" = "Diagnostics only supported for objects of class epidist_fit"
))
}
if (fit$algorithm %in% c("laplace", "meanfield", "fullrank", "pathfinder")) {
cli::cli_abort(c(
"!" = paste0(
"Diagnostics not yet supported for the algorithm: ", fit$algorithm
)
))
}
if (fit$algorithm == "sampling") {
np <- brms::nuts_params(fit)
divergent_indices <- np$Parameter == "divergent__"
treedepth_indices <- np$Parameter == "treedepth__"
diagnostics <- data.table(
"time" = sum(rstan::get_elapsed_time(fit$fit)),
"samples" = nrow(np) / length(unique(np$Parameter)),
"max_rhat" = round(max(brms::rhat(fit)), 3),
"divergent_transitions" = sum(np[divergent_indices, ]$Value),
"per_divergent_transitions" = mean(np[divergent_indices, ]$Value),
"max_treedepth" = max(np[treedepth_indices, ]$Value)
)
diagnostics[, no_at_max_treedepth :=
sum(np[treedepth_indices, ]$Value == max_treedepth)]
diagnostics[, per_at_max_treedepth := no_at_max_treedepth / samples]
} else {
cli::cli_abort(c(
"!" = paste0("Unrecognised algorithm: ", fit$algorithm)
))
}
return(diagnostics)
}
56 changes: 0 additions & 56 deletions R/fitting-and-postprocessing.R
Original file line number Diff line number Diff line change
@@ -1,59 +1,3 @@
#' Sample from the posterior of a model with additional diagnositics
#'
#' @param model ...
#' @param data ...
#' @param scenario ...
#' @param diagnostics ...
#' @param ... ...
#' @family postprocess
#' @autoglobal
#' @export
sample_model <- function(model, data, scenario = data.table::data.table(id = 1),
diagnostics = TRUE, ...) {
out <- data.table::copy(scenario)

# Setup failure tolerant model fitting
fit_model <- function(model, data, ...) {
fit <- cmdstanr::cmdstan_model(model)$sample(data = data, ...)
print(fit)
return(fit)
}
safe_fit_model <- purrr::safely(fit_model)
fit <- safe_fit_model(model, data, ...)

if (!is.null(fit$error)) {
out[, error := list(fit$error[[1]])]
diagnostics <- FALSE
}else {
out[, fit := list(fit$result)]
fit <- fit$result
}

if (diagnostics) {
diag <- fit$sampler_diagnostics(format = "df")
diagnostics <- data.table(
samples = nrow(diag),
max_rhat = round(max(
fit$summary(
variables = NULL, posterior::rhat,
.args = list(na.rm = TRUE)
)$`posterior::rhat`,
na.rm = TRUE
), 2),
divergent_transitions = sum(diag$divergent__),
per_divergent_transitions = sum(diag$divergent__) / nrow(diag),
max_treedepth = max(diag$treedepth__)
)
diagnostics[, no_at_max_treedepth := sum(diag$treedepth__ == max_treedepth)]
diagnostics[, per_at_max_treedepth := no_at_max_treedepth / nrow(diag)]
out <- cbind(out, diagnostics)

timing <- round(fit$time()$total, 1)
out[, run_time := timing]
}
return(out[])
}

#' Add natural scale summary parameters for a lognormal distribution
#'
#' @param dt ...
Expand Down
32 changes: 4 additions & 28 deletions R/globals.R
Original file line number Diff line number Diff line change
@@ -1,30 +1,22 @@
# Generated by roxyglobals: do not edit by hand

utils::globalVariables(c(
":=", # <sample_model>
"error", # <sample_model>
"no_at_max_treedepth", # <sample_model>
"max_treedepth", # <sample_model>
"per_at_max_treedepth", # <sample_model>
"run_time", # <sample_model>
":=", # <add_natural_scale_mean_sd>
"no_at_max_treedepth", # <epidist_diagnostics>
"max_treedepth", # <epidist_diagnostics>
"per_at_max_treedepth", # <epidist_diagnostics>
"samples", # <epidist_diagnostics>
"meanlog", # <add_natural_scale_mean_sd>
"sdlog", # <add_natural_scale_mean_sd>
"sd", # <add_natural_scale_mean_sd>
":=", # <extract_lognormal_draws>
"sdlog", # <extract_lognormal_draws>
"sdlog_log", # <extract_lognormal_draws>
"meanlog", # <extract_lognormal_draws>
"id", # <extract_lognormal_draws>
":=", # <make_relative_to_truth>
"true_value", # <make_relative_to_truth>
"value", # <make_relative_to_truth>
"rel_value", # <make_relative_to_truth>
":=", # <summarise_variable>
"value", # <summarise_variable>
":=", # <as_latent_individual.data.frame>
"id", # <as_latent_individual.data.frame>
".N", # <as_latent_individual.data.frame>
"obs_t", # <as_latent_individual.data.frame>
"obs_at", # <as_latent_individual.data.frame>
"ptime_lwr", # <as_latent_individual.data.frame>
Expand All @@ -38,7 +30,6 @@ utils::globalVariables(c(
"row_id", # <as_latent_individual.data.frame>
"woverlap", # <epidist_stancode.epidist_latent_individual>
"row_id", # <epidist_stancode.epidist_latent_individual>
":=", # <observe_process>
"ptime_daily", # <observe_process>
"ptime", # <observe_process>
"ptime_lwr", # <observe_process>
Expand All @@ -51,25 +42,21 @@ utils::globalVariables(c(
"delay_lwr", # <observe_process>
"delay_upr", # <observe_process>
"obs_at", # <observe_process>
":=", # <filter_obs_by_obs_time>
"obs_at", # <filter_obs_by_obs_time>
"ptime", # <filter_obs_by_obs_time>
"censored_obs_time", # <filter_obs_by_obs_time>
"ptime_lwr", # <filter_obs_by_obs_time>
"censored", # <filter_obs_by_obs_time>
"stime_upr", # <filter_obs_by_obs_time>
":=", # <filter_obs_by_ptime>
"censored", # <filter_obs_by_ptime>
"ptime_upr", # <filter_obs_by_ptime>
"stime_upr", # <filter_obs_by_ptime>
"ptime", # <filter_obs_by_ptime>
"censored_obs_time", # <filter_obs_by_ptime>
"ptime_lwr", # <filter_obs_by_ptime>
"censored_obs_time", # <pad_zero>
":=", # <pad_zero>
"delay_lwr", # <pad_zero>
"delay_daily", # <pad_zero>
":=", # <plot_relative_recovery>
"value", # <plot_relative_recovery>
"rel_value", # <plot_relative_recovery>
"case_type", # <plot_cases_by_obs_window>
Expand All @@ -85,23 +72,15 @@ utils::globalVariables(c(
"ptime_daily", # <plot_mean_posterior_pred>
"n", # <plot_mean_posterior_pred>
"obs_horizon", # <plot_mean_posterior_pred>
".data", # <plot_mean_posterior_pred>
":=", # <linelist_to_counts>
"time", # <linelist_to_counts>
".N", # <linelist_to_counts>
":=", # <linelist_to_cases>
"primary", # <linelist_to_cases>
"secondary", # <linelist_to_cases>
":=", # <reverse_obs_at>
"obs_at", # <reverse_obs_at>
"stime", # <construct_cases_by_obs_window>
"ptime", # <construct_cases_by_obs_window>
":=", # <construct_cases_by_obs_window>
"case_type", # <construct_cases_by_obs_window>
":=", # <combine_obs>
"obs_at", # <combine_obs>
"stime_daily", # <combine_obs>
":=", # <calculate_censor_delay>
"ptime_delay", # <calculate_censor_delay>
"ptime", # <calculate_censor_delay>
"ptime_daily", # <calculate_censor_delay>
Expand All @@ -111,12 +90,9 @@ utils::globalVariables(c(
"stime_delay", # <calculate_censor_delay>
"stime", # <calculate_censor_delay>
"stime_daily", # <calculate_censor_delay>
".N", # <event_to_incidence>
"ptime_daily", # <event_to_incidence>
"rlnorm", # <simulate_secondary>
":=", # <simulate_secondary>
"delay", # <simulate_secondary>
".N", # <simulate_secondary>
"stime", # <simulate_secondary>
"ptime", # <simulate_secondary>
NULL
Expand Down
4 changes: 4 additions & 0 deletions _pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -51,3 +51,7 @@ reference:
- title: Utility functions
contents:
- has_concept("utils")
- title: Diagnostic functions
contents:
- has_concept("diagnostics")

1 change: 0 additions & 1 deletion man/add_natural_scale_mean_sd.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion man/draws_to_long.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

35 changes: 35 additions & 0 deletions man/epidist_diagnostics.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion man/extract_lognormal_draws.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion man/make_relative_to_truth.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

38 changes: 0 additions & 38 deletions man/sample_model.Rd

This file was deleted.

1 change: 0 additions & 1 deletion man/summarise_draws.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion man/summarise_variable.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 4 additions & 9 deletions tests/testthat/helper-expectations.R
Original file line number Diff line number Diff line change
@@ -1,12 +1,7 @@
expect_convergence <- function(fit, per_dts = 0.05, treedepth = 10,
rhat = 1.05) {
np <- brms::nuts_params(fit)
divergent_indices <- np$Parameter == "divergent__"
per_divergent_transitions <- mean(np[divergent_indices, ]$Value)
treedepth_indices <- np$Parameter == "treedepth__"
max_treedepth <- max(np[treedepth_indices, ]$Value)
max_rhat <- max(brms::rhat(fit))
testthat::expect_lt(per_divergent_transitions, per_dts)
testthat::expect_lt(max_treedepth, treedepth)
testthat::expect_lt(max_rhat, rhat)
diag <- epidist_diagnostics(fit)
testthat::expect_lt(diag$per_divergent_transitions, per_dts)
testthat::expect_lt(diag$max_treedepth, treedepth)
testthat::expect_lt(diag$max_rhat, rhat)
}
Loading

0 comments on commit 55ab8bf

Please sign in to comment.