From 55ab8bf12517dfb29d09f7251e8c25f827a6e3c7 Mon Sep 17 00:00:00 2001 From: Adam Howes Date: Fri, 19 Jul 2024 11:24:01 +0100 Subject: [PATCH] Issue 163: Create `epidist_diagnostics` (#175) * Remove sample_model * Start to write epidist_diagnostics * Rebase merged PRs * Don't need this to be S3 * Finish removal of S3 and improve tests * Replace custom code with diagnostics function * Lint and documentation fixes * Add documentation to epidist_diagnostics * Lint and add diagnostics as section * Add more unit tests for diagnostics * expect_numeric doesn't exist * Make tests softer --- NAMESPACE | 4 +- R/diagnostics.R | 59 ++++++++++++++++++++++++++ R/fitting-and-postprocessing.R | 56 ------------------------ R/globals.R | 32 ++------------ _pkgdown.yml | 4 ++ man/add_natural_scale_mean_sd.Rd | 1 - man/draws_to_long.Rd | 1 - man/epidist_diagnostics.Rd | 35 +++++++++++++++ man/extract_lognormal_draws.Rd | 1 - man/make_relative_to_truth.Rd | 1 - man/sample_model.Rd | 38 ----------------- man/summarise_draws.Rd | 1 - man/summarise_variable.Rd | 1 - tests/testthat/helper-expectations.R | 13 ++---- tests/testthat/test-unit-diagnostics.R | 31 ++++++++++++++ 15 files changed, 140 insertions(+), 138 deletions(-) create mode 100644 R/diagnostics.R create mode 100644 man/epidist_diagnostics.Rd delete mode 100644 man/sample_model.Rd create mode 100644 tests/testthat/test-unit-diagnostics.R diff --git a/NAMESPACE b/NAMESPACE index 373e41295..da305ab6c 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -22,6 +22,7 @@ export(construct_cases_by_obs_window) export(draws_to_long) export(drop_zero) export(epidist) +export(epidist_diagnostics) export(epidist_family) export(epidist_formula) export(epidist_prior) @@ -47,7 +48,6 @@ export(plot_mean_posterior_pred) export(plot_recovery) export(plot_relative_recovery) export(reverse_obs_at) -export(sample_model) export(simulate_double_censored_pmf) export(simulate_exponential_cases) export(simulate_gillespie) @@ -63,6 +63,8 @@ importFrom(checkmate,assert_data_frame) importFrom(checkmate,assert_int) importFrom(checkmate,assert_names) importFrom(checkmate,assert_numeric) +importFrom(cli,cli_abort) +importFrom(cli,cli_inform) importFrom(posterior,as_draws_df) importFrom(rstan,lookup) importFrom(stats,as.formula) diff --git a/R/diagnostics.R b/R/diagnostics.R new file mode 100644 index 000000000..0ffea953f --- /dev/null +++ b/R/diagnostics.R @@ -0,0 +1,59 @@ +#' Diagnostics for `epidist_fit` models +#' +#' This function computes diagnostics to assess the quality of a fitted model. +#' When the fitting algorithm used is `"sampling"` (HMC) then the output of +#' `epidist_diagnostics` is a `data.table` containing: +#' +#' * `time`: the total time taken to fit all chains +#' * `samples`: the total number of samples across all chains +#' * `max_rhat`: the highest value of the Gelman-Rubin statistic +#' * `divergent_transitions`: the total number of divergent transitions +#' * `per_divergent_transitions`: the proportion of samples which had divergent +#' transitions +#' * `max_treedepth`: the highest value of the treedepth HMC parameter +#' * `no_at_max_treedepth`: the number of samples which attained the +#' `max_treedepth` +#' * `per_at_max_treedepth`: the proportion of samples which attained the +#' `max_treedepth` +#' +#' When the fitting algorithm is not `"sampling"` (see `?brms::brm` for other +#' possible algorithms) then diagnostics are yet to be implemented. +#' @param fit A fitted model of class `epidist_fit` +#' @family diagnostics +#' @autoglobal +#' @export +epidist_diagnostics <- function(fit) { + if (!inherits(fit, "epidist_fit")) { + cli::cli_abort(c( + "!" = "Diagnostics only supported for objects of class epidist_fit" + )) + } + if (fit$algorithm %in% c("laplace", "meanfield", "fullrank", "pathfinder")) { + cli::cli_abort(c( + "!" = paste0( + "Diagnostics not yet supported for the algorithm: ", fit$algorithm + ) + )) + } + if (fit$algorithm == "sampling") { + np <- brms::nuts_params(fit) + divergent_indices <- np$Parameter == "divergent__" + treedepth_indices <- np$Parameter == "treedepth__" + diagnostics <- data.table( + "time" = sum(rstan::get_elapsed_time(fit$fit)), + "samples" = nrow(np) / length(unique(np$Parameter)), + "max_rhat" = round(max(brms::rhat(fit)), 3), + "divergent_transitions" = sum(np[divergent_indices, ]$Value), + "per_divergent_transitions" = mean(np[divergent_indices, ]$Value), + "max_treedepth" = max(np[treedepth_indices, ]$Value) + ) + diagnostics[, no_at_max_treedepth := + sum(np[treedepth_indices, ]$Value == max_treedepth)] + diagnostics[, per_at_max_treedepth := no_at_max_treedepth / samples] + } else { + cli::cli_abort(c( + "!" = paste0("Unrecognised algorithm: ", fit$algorithm) + )) + } + return(diagnostics) +} diff --git a/R/fitting-and-postprocessing.R b/R/fitting-and-postprocessing.R index 9ee948a24..17021aab4 100644 --- a/R/fitting-and-postprocessing.R +++ b/R/fitting-and-postprocessing.R @@ -1,59 +1,3 @@ -#' Sample from the posterior of a model with additional diagnositics -#' -#' @param model ... -#' @param data ... -#' @param scenario ... -#' @param diagnostics ... -#' @param ... ... -#' @family postprocess -#' @autoglobal -#' @export -sample_model <- function(model, data, scenario = data.table::data.table(id = 1), - diagnostics = TRUE, ...) { - out <- data.table::copy(scenario) - - # Setup failure tolerant model fitting - fit_model <- function(model, data, ...) { - fit <- cmdstanr::cmdstan_model(model)$sample(data = data, ...) - print(fit) - return(fit) - } - safe_fit_model <- purrr::safely(fit_model) - fit <- safe_fit_model(model, data, ...) - - if (!is.null(fit$error)) { - out[, error := list(fit$error[[1]])] - diagnostics <- FALSE - }else { - out[, fit := list(fit$result)] - fit <- fit$result - } - - if (diagnostics) { - diag <- fit$sampler_diagnostics(format = "df") - diagnostics <- data.table( - samples = nrow(diag), - max_rhat = round(max( - fit$summary( - variables = NULL, posterior::rhat, - .args = list(na.rm = TRUE) - )$`posterior::rhat`, - na.rm = TRUE - ), 2), - divergent_transitions = sum(diag$divergent__), - per_divergent_transitions = sum(diag$divergent__) / nrow(diag), - max_treedepth = max(diag$treedepth__) - ) - diagnostics[, no_at_max_treedepth := sum(diag$treedepth__ == max_treedepth)] - diagnostics[, per_at_max_treedepth := no_at_max_treedepth / nrow(diag)] - out <- cbind(out, diagnostics) - - timing <- round(fit$time()$total, 1) - out[, run_time := timing] - } - return(out[]) -} - #' Add natural scale summary parameters for a lognormal distribution #' #' @param dt ... diff --git a/R/globals.R b/R/globals.R index 88d1e6932..f961f4e21 100644 --- a/R/globals.R +++ b/R/globals.R @@ -1,30 +1,22 @@ # Generated by roxyglobals: do not edit by hand utils::globalVariables(c( - ":=", # - "error", # - "no_at_max_treedepth", # - "max_treedepth", # - "per_at_max_treedepth", # - "run_time", # - ":=", # + "no_at_max_treedepth", # + "max_treedepth", # + "per_at_max_treedepth", # + "samples", # "meanlog", # "sdlog", # "sd", # - ":=", # "sdlog", # "sdlog_log", # "meanlog", # "id", # - ":=", # "true_value", # "value", # "rel_value", # - ":=", # "value", # - ":=", # "id", # - ".N", # "obs_t", # "obs_at", # "ptime_lwr", # @@ -38,7 +30,6 @@ utils::globalVariables(c( "row_id", # "woverlap", # "row_id", # - ":=", # "ptime_daily", # "ptime", # "ptime_lwr", # @@ -51,14 +42,12 @@ utils::globalVariables(c( "delay_lwr", # "delay_upr", # "obs_at", # - ":=", # "obs_at", # "ptime", # "censored_obs_time", # "ptime_lwr", # "censored", # "stime_upr", # - ":=", # "censored", # "ptime_upr", # "stime_upr", # @@ -66,10 +55,8 @@ utils::globalVariables(c( "censored_obs_time", # "ptime_lwr", # "censored_obs_time", # - ":=", # "delay_lwr", # "delay_daily", # - ":=", # "value", # "rel_value", # "case_type", # @@ -85,23 +72,15 @@ utils::globalVariables(c( "ptime_daily", # "n", # "obs_horizon", # - ".data", # - ":=", # "time", # - ".N", # - ":=", # "primary", # "secondary", # - ":=", # "obs_at", # "stime", # "ptime", # - ":=", # "case_type", # - ":=", # "obs_at", # "stime_daily", # - ":=", # "ptime_delay", # "ptime", # "ptime_daily", # @@ -111,12 +90,9 @@ utils::globalVariables(c( "stime_delay", # "stime", # "stime_daily", # - ".N", # "ptime_daily", # "rlnorm", # - ":=", # "delay", # - ".N", # "stime", # "ptime", # NULL diff --git a/_pkgdown.yml b/_pkgdown.yml index b2db49a97..f0599ef37 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -51,3 +51,7 @@ reference: - title: Utility functions contents: - has_concept("utils") +- title: Diagnostic functions + contents: + - has_concept("diagnostics") + diff --git a/man/add_natural_scale_mean_sd.Rd b/man/add_natural_scale_mean_sd.Rd index 3e54240d4..d0754226d 100644 --- a/man/add_natural_scale_mean_sd.Rd +++ b/man/add_natural_scale_mean_sd.Rd @@ -17,7 +17,6 @@ Other postprocess: \code{\link{draws_to_long}()}, \code{\link{extract_lognormal_draws}()}, \code{\link{make_relative_to_truth}()}, -\code{\link{sample_model}()}, \code{\link{summarise_draws}()}, \code{\link{summarise_variable}()} } diff --git a/man/draws_to_long.Rd b/man/draws_to_long.Rd index 631e5adb6..f4b346ef9 100644 --- a/man/draws_to_long.Rd +++ b/man/draws_to_long.Rd @@ -17,7 +17,6 @@ Other postprocess: \code{\link{add_natural_scale_mean_sd}()}, \code{\link{extract_lognormal_draws}()}, \code{\link{make_relative_to_truth}()}, -\code{\link{sample_model}()}, \code{\link{summarise_draws}()}, \code{\link{summarise_variable}()} } diff --git a/man/epidist_diagnostics.Rd b/man/epidist_diagnostics.Rd new file mode 100644 index 000000000..4a44b5e8f --- /dev/null +++ b/man/epidist_diagnostics.Rd @@ -0,0 +1,35 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/diagnostics.R +\name{epidist_diagnostics} +\alias{epidist_diagnostics} +\title{Diagnostics for \code{epidist_fit} models} +\usage{ +epidist_diagnostics(fit) +} +\arguments{ +\item{fit}{A fitted model of class \code{epidist_fit}} +} +\description{ +This function computes diagnostics to assess the quality of a fitted model. +When the fitting algorithm used is \code{"sampling"} (HMC) then the output of +\code{epidist_diagnostics} is a \code{data.table} containing: +} +\details{ +\itemize{ +\item \code{time}: the total time taken to fit all chains +\item \code{samples}: the total number of samples across all chains +\item \code{max_rhat}: the highest value of the Gelman-Rubin statistic +\item \code{divergent_transitions}: the total number of divergent transitions +\item \code{per_divergent_transitions}: the proportion of samples which had divergent +transitions +\item \code{max_treedepth}: the highest value of the treedepth HMC parameter +\item \code{no_at_max_treedepth}: the number of samples which attained the +\code{max_treedepth} +\item \code{per_at_max_treedepth}: the proportion of samples which attained the +\code{max_treedepth} +} + +When the fitting algorithm is not \code{"sampling"} (see \code{?brms::brm} for other +possible algorithms) then diagnostics are yet to be implemented. +} +\concept{diagnostics} diff --git a/man/extract_lognormal_draws.Rd b/man/extract_lognormal_draws.Rd index 7b1a4c69f..d7270f874 100644 --- a/man/extract_lognormal_draws.Rd +++ b/man/extract_lognormal_draws.Rd @@ -21,7 +21,6 @@ Other postprocess: \code{\link{add_natural_scale_mean_sd}()}, \code{\link{draws_to_long}()}, \code{\link{make_relative_to_truth}()}, -\code{\link{sample_model}()}, \code{\link{summarise_draws}()}, \code{\link{summarise_variable}()} } diff --git a/man/make_relative_to_truth.Rd b/man/make_relative_to_truth.Rd index 446bec6de..97cd56564 100644 --- a/man/make_relative_to_truth.Rd +++ b/man/make_relative_to_truth.Rd @@ -21,7 +21,6 @@ Other postprocess: \code{\link{add_natural_scale_mean_sd}()}, \code{\link{draws_to_long}()}, \code{\link{extract_lognormal_draws}()}, -\code{\link{sample_model}()}, \code{\link{summarise_draws}()}, \code{\link{summarise_variable}()} } diff --git a/man/sample_model.Rd b/man/sample_model.Rd deleted file mode 100644 index 22506ef23..000000000 --- a/man/sample_model.Rd +++ /dev/null @@ -1,38 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/fitting-and-postprocessing.R -\name{sample_model} -\alias{sample_model} -\title{Sample from the posterior of a model with additional diagnositics} -\usage{ -sample_model( - model, - data, - scenario = data.table::data.table(id = 1), - diagnostics = TRUE, - ... -) -} -\arguments{ -\item{model}{...} - -\item{data}{...} - -\item{scenario}{...} - -\item{diagnostics}{...} - -\item{...}{...} -} -\description{ -Sample from the posterior of a model with additional diagnositics -} -\seealso{ -Other postprocess: -\code{\link{add_natural_scale_mean_sd}()}, -\code{\link{draws_to_long}()}, -\code{\link{extract_lognormal_draws}()}, -\code{\link{make_relative_to_truth}()}, -\code{\link{summarise_draws}()}, -\code{\link{summarise_variable}()} -} -\concept{postprocess} diff --git a/man/summarise_draws.Rd b/man/summarise_draws.Rd index 96f5e5b5e..d8d0b29a9 100644 --- a/man/summarise_draws.Rd +++ b/man/summarise_draws.Rd @@ -24,7 +24,6 @@ Other postprocess: \code{\link{draws_to_long}()}, \code{\link{extract_lognormal_draws}()}, \code{\link{make_relative_to_truth}()}, -\code{\link{sample_model}()}, \code{\link{summarise_variable}()} } \concept{postprocess} diff --git a/man/summarise_variable.Rd b/man/summarise_variable.Rd index 048a7cbac..c7a2dc30f 100644 --- a/man/summarise_variable.Rd +++ b/man/summarise_variable.Rd @@ -24,7 +24,6 @@ Other postprocess: \code{\link{draws_to_long}()}, \code{\link{extract_lognormal_draws}()}, \code{\link{make_relative_to_truth}()}, -\code{\link{sample_model}()}, \code{\link{summarise_draws}()} } \concept{postprocess} diff --git a/tests/testthat/helper-expectations.R b/tests/testthat/helper-expectations.R index 48721fa3f..ee8d3041e 100644 --- a/tests/testthat/helper-expectations.R +++ b/tests/testthat/helper-expectations.R @@ -1,12 +1,7 @@ expect_convergence <- function(fit, per_dts = 0.05, treedepth = 10, rhat = 1.05) { - np <- brms::nuts_params(fit) - divergent_indices <- np$Parameter == "divergent__" - per_divergent_transitions <- mean(np[divergent_indices, ]$Value) - treedepth_indices <- np$Parameter == "treedepth__" - max_treedepth <- max(np[treedepth_indices, ]$Value) - max_rhat <- max(brms::rhat(fit)) - testthat::expect_lt(per_divergent_transitions, per_dts) - testthat::expect_lt(max_treedepth, treedepth) - testthat::expect_lt(max_rhat, rhat) + diag <- epidist_diagnostics(fit) + testthat::expect_lt(diag$per_divergent_transitions, per_dts) + testthat::expect_lt(diag$max_treedepth, treedepth) + testthat::expect_lt(diag$max_rhat, rhat) } diff --git a/tests/testthat/test-unit-diagnostics.R b/tests/testthat/test-unit-diagnostics.R new file mode 100644 index 000000000..e8e57496a --- /dev/null +++ b/tests/testthat/test-unit-diagnostics.R @@ -0,0 +1,31 @@ +test_that("epidist_diagnostics", { # nolint: line_length_linter. + skip_on_cran() + set.seed(1) + prep_obs <- as_latent_individual(sim_obs) + fit <- epidist(data = prep_obs, seed = 1) + diag <- epidist_diagnostics(fit) + expected_names <- c( + "time", "samples", "max_rhat", "divergent_transitions", + "per_divergent_transitions", "max_treedepth", "no_at_max_treedepth", + "per_at_max_treedepth" + ) + expect_equal(names(diag), expected_names) + expect_gt(diag$time, 0) + expect_gt(diag$samples, 0) + expect_gt(diag$max_rhat, 0.9) + expect_lt(diag$max_rhat, 1.1) + expect_gte(diag$divergent_transitions, 0) + expect_lt(diag$divergent_transitions, diag$samples) + expect_lt(diag$max_treedepth, 12) + expect_lte(diag$no_at_max_treedepth, diag$samples) + expect_lte(diag$per_at_max_treedepth, 1) + expect_gt(diag$per_at_max_treedepth, 0) +}) + +test_that("epidist_diagnostics gives an error when passed model fit using the Laplace algorithm", { # nolint: line_length_linter. + skip_on_cran() + set.seed(1) + prep_obs <- as_latent_individual(sim_obs) + fit_laplace <- epidist(data = prep_obs, seed = 1, algorithm = "laplace") + expect_error(epidist_diagnostics(fit_laplace)) +})