From bb61daa73252f92fd6adbaef7e40828141fbb026 Mon Sep 17 00:00:00 2001 From: athowes Date: Wed, 17 Jul 2024 15:29:42 +0100 Subject: [PATCH] Progress on getting diagnostics into a function --- NAMESPACE | 5 +- R/diagnostics.R | 65 ++++++++++++++------------ R/globals.R | 5 -- man/add_natural_scale_mean_sd.Rd | 1 - man/correct_primary_censoring_bias.Rd | 1 - man/draws_to_long.Rd | 1 - man/epidist.Rd | 1 + man/epidist_diagnostics.Rd | 23 +++++++++ man/epidist_diagnostics.default.Rd | 24 ++++++++++ man/epidist_diagnostics.epidist_fit.Rd | 24 ++++++++++ man/epidist_family.Rd | 1 + man/epidist_family.default.Rd | 2 + man/epidist_formula.Rd | 1 + man/epidist_formula.default.Rd | 2 + man/epidist_prior.Rd | 1 + man/epidist_prior.default.Rd | 2 + man/epidist_stancode.Rd | 1 + man/epidist_stancode.default.Rd | 2 + man/extract_lognormal_draws.Rd | 1 - man/make_relative_to_truth.Rd | 1 - man/sample_model.Rd | 39 ---------------- man/summarise_draws.Rd | 1 - man/summarise_variable.Rd | 1 - tests/testthat/test-unit-diagnostics.R | 16 +++++++ 24 files changed, 138 insertions(+), 83 deletions(-) create mode 100644 man/epidist_diagnostics.Rd create mode 100644 man/epidist_diagnostics.default.Rd create mode 100644 man/epidist_diagnostics.epidist_fit.Rd delete mode 100644 man/sample_model.Rd create mode 100644 tests/testthat/test-unit-diagnostics.R diff --git a/NAMESPACE b/NAMESPACE index c04f42354..2bae6e492 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -2,6 +2,8 @@ S3method(as_latent_individual,data.frame) S3method(epidist,default) +S3method(epidist_diagnostics,default) +S3method(epidist_diagnostics,epidist_fit) S3method(epidist_family,default) S3method(epidist_family,epidist_latent_individual) S3method(epidist_formula,default) @@ -21,6 +23,7 @@ export(correct_primary_censoring_bias) export(draws_to_long) export(drop_zero) export(epidist) +export(epidist_diagnostics) export(epidist_family) export(epidist_formula) export(epidist_prior) @@ -45,7 +48,6 @@ export(plot_mean_posterior_pred) export(plot_recovery) export(plot_relative_recovery) export(reverse_obs_at) -export(sample_model) export(simulate_double_censored_pmf) export(simulate_exponential_cases) export(simulate_gillespie) @@ -63,6 +65,7 @@ importFrom(checkmate,assert_int) importFrom(checkmate,assert_names) importFrom(checkmate,assert_numeric) importFrom(posterior,as_draws_df) +importFrom(rstan,lookup) importFrom(stats,as.formula) importFrom(stats,ecdf) importFrom(stats,integrate) diff --git a/R/diagnostics.R b/R/diagnostics.R index 002e8881d..0620f8a44 100644 --- a/R/diagnostics.R +++ b/R/diagnostics.R @@ -4,13 +4,14 @@ #' @param fit ... #' @export epidist_diagnostics <- function(fit) { - UseMethod("epidist") + UseMethod("epidist_diagnostics") } #' Default method for returning diagnostics #' #' @param fit ... #' @family defaults +#' @method epidist_diagnostics default #' @export epidist_diagnostics.default <- function(fit) { stop( @@ -23,36 +24,38 @@ epidist_diagnostics.default <- function(fit) { #' #' @param fit ... #' @family defaults +#' @method epidist_diagnostics epidist_fit #' @export -epidist_diagnostics.default <- function(fit) { - stop( - "No epidist_diagnostics method implemented for the class ", class(data), "\n", - "See methods(epidist_diagnostics) for available methods" - ) -} - -#' Default method for returning diagnostics -#' -#' @param fit ... -#' @family defaults -#' @export -epidist_diagnostics <- function(fit) { - # diag <- fit$sampler_diagnostics(format = "df") - # diagnostics <- data.table( - # samples = nrow(diag), - # max_rhat = round(max( - # fit$summary( - # variables = NULL, posterior::rhat, - # .args = list(na.rm = TRUE) - # )$`posterior::rhat`, - # na.rm = TRUE - # ), 2), - # divergent_transitions = sum(diag$divergent__), - # per_divergent_transitions = sum(diag$divergent__) / nrow(diag), - # max_treedepth = max(diag$treedepth__) - # ) - # diagnostics[, no_at_max_treedepth := sum(diag$treedepth__ == max_treedepth)] - # diagnostics[, per_at_max_treedepth := no_at_max_treedepth / nrow(diag)] +epidist_diagnostics.epidist_fit <- function(fit) { + if (fit$algorithm %in% c("laplace", "meanfield", "fullrank", "pathfinder")) { + cli::cli_abort(c( + "!" = paste0( + "Diagnostics not yet supported for the algorithm: ", fit$algorithm + ) + )) + } else if (!fit$algorithm == "sampling") { + cli::cli_abort(c( + "!" = paste0( + "Unrecognised algorithm: ", fit$algorithm + ) + )) + } else if (fit$algorithm == "sampling") { + np <- brms::nuts_params(fit) + divergent_indices <- np$Parameter == "divergent__" + treedepth_indices <- np$Parameter == "treedepth__" + diagnostics <- data.table( + "samples" = nrow(np) / length(unique(np$Parameter)), + "max_rhat" = round(max(brms::rhat(fit)), 3), + "divergent_transitions" = sum(np[divergent_indices, ]$Value), + "per_divergent_transitions" = mean(np[divergent_indices, ]$Value), + "max_treedepth" = max(np[treedepth_indices, ]$Value) + ) + diagnostics[, no_at_max_treedepth := + sum(np[treedepth_indices, ]$Value == max_treedepth)] + diagnostics[, per_at_max_treedepth := no_at_max_treedepth / samples] + } + + # rstan::get_elapsed_time(fit$fit) - # timing <- round(fit$time()$total, 1) + return(diagnostics) } diff --git a/R/globals.R b/R/globals.R index b4a8d25ba..eb962514c 100644 --- a/R/globals.R +++ b/R/globals.R @@ -1,11 +1,6 @@ # Generated by roxyglobals: do not edit by hand utils::globalVariables(c( - "error", # - "no_at_max_treedepth", # - "max_treedepth", # - "per_at_max_treedepth", # - "run_time", # "meanlog", # "sdlog", # "sd", # diff --git a/man/add_natural_scale_mean_sd.Rd b/man/add_natural_scale_mean_sd.Rd index 269283932..bde478862 100644 --- a/man/add_natural_scale_mean_sd.Rd +++ b/man/add_natural_scale_mean_sd.Rd @@ -18,7 +18,6 @@ Other postprocess: \code{\link{draws_to_long}()}, \code{\link{extract_lognormal_draws}()}, \code{\link{make_relative_to_truth}()}, -\code{\link{sample_model}()}, \code{\link{summarise_draws}()}, \code{\link{summarise_variable}()} } diff --git a/man/correct_primary_censoring_bias.Rd b/man/correct_primary_censoring_bias.Rd index 1b167e343..2c64112cc 100644 --- a/man/correct_primary_censoring_bias.Rd +++ b/man/correct_primary_censoring_bias.Rd @@ -18,7 +18,6 @@ Other postprocess: \code{\link{draws_to_long}()}, \code{\link{extract_lognormal_draws}()}, \code{\link{make_relative_to_truth}()}, -\code{\link{sample_model}()}, \code{\link{summarise_draws}()}, \code{\link{summarise_variable}()} } diff --git a/man/draws_to_long.Rd b/man/draws_to_long.Rd index 0c64dfdaf..e5268fbfc 100644 --- a/man/draws_to_long.Rd +++ b/man/draws_to_long.Rd @@ -18,7 +18,6 @@ Other postprocess: \code{\link{correct_primary_censoring_bias}()}, \code{\link{extract_lognormal_draws}()}, \code{\link{make_relative_to_truth}()}, -\code{\link{sample_model}()}, \code{\link{summarise_draws}()}, \code{\link{summarise_variable}()} } diff --git a/man/epidist.Rd b/man/epidist.Rd index 3c4d9c8ea..889f7417f 100644 --- a/man/epidist.Rd +++ b/man/epidist.Rd @@ -25,6 +25,7 @@ Interface using \code{brms} } \seealso{ Other generics: +\code{\link{epidist_diagnostics}()}, \code{\link{epidist_family}()}, \code{\link{epidist_formula}()}, \code{\link{epidist_prior}()}, diff --git a/man/epidist_diagnostics.Rd b/man/epidist_diagnostics.Rd new file mode 100644 index 000000000..f05799bb6 --- /dev/null +++ b/man/epidist_diagnostics.Rd @@ -0,0 +1,23 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/diagnostics.R +\name{epidist_diagnostics} +\alias{epidist_diagnostics} +\title{Diagnostics for fitted epidist model} +\usage{ +epidist_diagnostics(fit) +} +\arguments{ +\item{fit}{...} +} +\description{ +Diagnostics for fitted epidist model +} +\seealso{ +Other generics: +\code{\link{epidist}()}, +\code{\link{epidist_family}()}, +\code{\link{epidist_formula}()}, +\code{\link{epidist_prior}()}, +\code{\link{epidist_stancode}()} +} +\concept{generics} diff --git a/man/epidist_diagnostics.default.Rd b/man/epidist_diagnostics.default.Rd new file mode 100644 index 000000000..48119131f --- /dev/null +++ b/man/epidist_diagnostics.default.Rd @@ -0,0 +1,24 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/diagnostics.R +\name{epidist_diagnostics.default} +\alias{epidist_diagnostics.default} +\title{Default method for returning diagnostics} +\usage{ +\method{epidist_diagnostics}{default}(fit) +} +\arguments{ +\item{fit}{...} +} +\description{ +Default method for returning diagnostics +} +\seealso{ +Other defaults: +\code{\link{epidist.default}()}, +\code{\link{epidist_diagnostics.epidist_fit}()}, +\code{\link{epidist_family.default}()}, +\code{\link{epidist_formula.default}()}, +\code{\link{epidist_prior.default}()}, +\code{\link{epidist_stancode.default}()} +} +\concept{defaults} diff --git a/man/epidist_diagnostics.epidist_fit.Rd b/man/epidist_diagnostics.epidist_fit.Rd new file mode 100644 index 000000000..3e8aeac65 --- /dev/null +++ b/man/epidist_diagnostics.epidist_fit.Rd @@ -0,0 +1,24 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/diagnostics.R +\name{epidist_diagnostics.epidist_fit} +\alias{epidist_diagnostics.epidist_fit} +\title{Default method for returning diagnostics} +\usage{ +\method{epidist_diagnostics}{epidist_fit}(fit) +} +\arguments{ +\item{fit}{...} +} +\description{ +Default method for returning diagnostics +} +\seealso{ +Other defaults: +\code{\link{epidist.default}()}, +\code{\link{epidist_diagnostics.default}()}, +\code{\link{epidist_family.default}()}, +\code{\link{epidist_formula.default}()}, +\code{\link{epidist_prior.default}()}, +\code{\link{epidist_stancode.default}()} +} +\concept{defaults} diff --git a/man/epidist_family.Rd b/man/epidist_family.Rd index f15d17f28..b771665ea 100644 --- a/man/epidist_family.Rd +++ b/man/epidist_family.Rd @@ -17,6 +17,7 @@ Define model specific family \seealso{ Other generics: \code{\link{epidist}()}, +\code{\link{epidist_diagnostics}()}, \code{\link{epidist_formula}()}, \code{\link{epidist_prior}()}, \code{\link{epidist_stancode}()} diff --git a/man/epidist_family.default.Rd b/man/epidist_family.default.Rd index 791cdf6fd..f95dd01e3 100644 --- a/man/epidist_family.default.Rd +++ b/man/epidist_family.default.Rd @@ -17,6 +17,8 @@ Default method for defining a model specific family \seealso{ Other defaults: \code{\link{epidist.default}()}, +\code{\link{epidist_diagnostics.default}()}, +\code{\link{epidist_diagnostics.epidist_fit}()}, \code{\link{epidist_formula.default}()}, \code{\link{epidist_prior.default}()}, \code{\link{epidist_stancode.default}()} diff --git a/man/epidist_formula.Rd b/man/epidist_formula.Rd index e900f5116..82819214e 100644 --- a/man/epidist_formula.Rd +++ b/man/epidist_formula.Rd @@ -17,6 +17,7 @@ Define a model specific formula \seealso{ Other generics: \code{\link{epidist}()}, +\code{\link{epidist_diagnostics}()}, \code{\link{epidist_family}()}, \code{\link{epidist_prior}()}, \code{\link{epidist_stancode}()} diff --git a/man/epidist_formula.default.Rd b/man/epidist_formula.default.Rd index ad5d57af1..cf0b143ea 100644 --- a/man/epidist_formula.default.Rd +++ b/man/epidist_formula.default.Rd @@ -17,6 +17,8 @@ Default method for defining a model specific formula \seealso{ Other defaults: \code{\link{epidist.default}()}, +\code{\link{epidist_diagnostics.default}()}, +\code{\link{epidist_diagnostics.epidist_fit}()}, \code{\link{epidist_family.default}()}, \code{\link{epidist_prior.default}()}, \code{\link{epidist_stancode.default}()} diff --git a/man/epidist_prior.Rd b/man/epidist_prior.Rd index 1dd39211a..9d71c056b 100644 --- a/man/epidist_prior.Rd +++ b/man/epidist_prior.Rd @@ -17,6 +17,7 @@ Define model specific priors \seealso{ Other generics: \code{\link{epidist}()}, +\code{\link{epidist_diagnostics}()}, \code{\link{epidist_family}()}, \code{\link{epidist_formula}()}, \code{\link{epidist_stancode}()} diff --git a/man/epidist_prior.default.Rd b/man/epidist_prior.default.Rd index 6f2c500c1..b8dfd16e7 100644 --- a/man/epidist_prior.default.Rd +++ b/man/epidist_prior.default.Rd @@ -17,6 +17,8 @@ Default method for defining model specific priors \seealso{ Other defaults: \code{\link{epidist.default}()}, +\code{\link{epidist_diagnostics.default}()}, +\code{\link{epidist_diagnostics.epidist_fit}()}, \code{\link{epidist_family.default}()}, \code{\link{epidist_formula.default}()}, \code{\link{epidist_stancode.default}()} diff --git a/man/epidist_stancode.Rd b/man/epidist_stancode.Rd index 8b8afe9da..640a7dee4 100644 --- a/man/epidist_stancode.Rd +++ b/man/epidist_stancode.Rd @@ -17,6 +17,7 @@ Define model specific Stan code \seealso{ Other generics: \code{\link{epidist}()}, +\code{\link{epidist_diagnostics}()}, \code{\link{epidist_family}()}, \code{\link{epidist_formula}()}, \code{\link{epidist_prior}()} diff --git a/man/epidist_stancode.default.Rd b/man/epidist_stancode.default.Rd index 0e448975f..2d39b9d3d 100644 --- a/man/epidist_stancode.default.Rd +++ b/man/epidist_stancode.default.Rd @@ -17,6 +17,8 @@ Default method for defining model specific Stan code \seealso{ Other defaults: \code{\link{epidist.default}()}, +\code{\link{epidist_diagnostics.default}()}, +\code{\link{epidist_diagnostics.epidist_fit}()}, \code{\link{epidist_family.default}()}, \code{\link{epidist_formula.default}()}, \code{\link{epidist_prior.default}()} diff --git a/man/extract_lognormal_draws.Rd b/man/extract_lognormal_draws.Rd index d154b6abb..ba590d0b8 100644 --- a/man/extract_lognormal_draws.Rd +++ b/man/extract_lognormal_draws.Rd @@ -22,7 +22,6 @@ Other postprocess: \code{\link{correct_primary_censoring_bias}()}, \code{\link{draws_to_long}()}, \code{\link{make_relative_to_truth}()}, -\code{\link{sample_model}()}, \code{\link{summarise_draws}()}, \code{\link{summarise_variable}()} } diff --git a/man/make_relative_to_truth.Rd b/man/make_relative_to_truth.Rd index 99573f12a..de0df0295 100644 --- a/man/make_relative_to_truth.Rd +++ b/man/make_relative_to_truth.Rd @@ -22,7 +22,6 @@ Other postprocess: \code{\link{correct_primary_censoring_bias}()}, \code{\link{draws_to_long}()}, \code{\link{extract_lognormal_draws}()}, -\code{\link{sample_model}()}, \code{\link{summarise_draws}()}, \code{\link{summarise_variable}()} } diff --git a/man/sample_model.Rd b/man/sample_model.Rd deleted file mode 100644 index fbb6a5e40..000000000 --- a/man/sample_model.Rd +++ /dev/null @@ -1,39 +0,0 @@ -% Generated by roxygen2: do not edit by hand -% Please edit documentation in R/fitting-and-postprocessing.R -\name{sample_model} -\alias{sample_model} -\title{Sample from the posterior of a model with additional diagnositics} -\usage{ -sample_model( - model, - data, - scenario = data.table::data.table(id = 1), - diagnostics = TRUE, - ... -) -} -\arguments{ -\item{model}{...} - -\item{data}{...} - -\item{scenario}{...} - -\item{diagnostics}{...} - -\item{...}{...} -} -\description{ -Sample from the posterior of a model with additional diagnositics -} -\seealso{ -Other postprocess: -\code{\link{add_natural_scale_mean_sd}()}, -\code{\link{correct_primary_censoring_bias}()}, -\code{\link{draws_to_long}()}, -\code{\link{extract_lognormal_draws}()}, -\code{\link{make_relative_to_truth}()}, -\code{\link{summarise_draws}()}, -\code{\link{summarise_variable}()} -} -\concept{postprocess} diff --git a/man/summarise_draws.Rd b/man/summarise_draws.Rd index 617082e01..d678fcdf9 100644 --- a/man/summarise_draws.Rd +++ b/man/summarise_draws.Rd @@ -25,7 +25,6 @@ Other postprocess: \code{\link{draws_to_long}()}, \code{\link{extract_lognormal_draws}()}, \code{\link{make_relative_to_truth}()}, -\code{\link{sample_model}()}, \code{\link{summarise_variable}()} } \concept{postprocess} diff --git a/man/summarise_variable.Rd b/man/summarise_variable.Rd index 955897398..207075edf 100644 --- a/man/summarise_variable.Rd +++ b/man/summarise_variable.Rd @@ -25,7 +25,6 @@ Other postprocess: \code{\link{draws_to_long}()}, \code{\link{extract_lognormal_draws}()}, \code{\link{make_relative_to_truth}()}, -\code{\link{sample_model}()}, \code{\link{summarise_draws}()} } \concept{postprocess} diff --git a/tests/testthat/test-unit-diagnostics.R b/tests/testthat/test-unit-diagnostics.R new file mode 100644 index 000000000..d1c06e7c7 --- /dev/null +++ b/tests/testthat/test-unit-diagnostics.R @@ -0,0 +1,16 @@ +test_that("", { # nolint: line_length_linter. + skip_on_cran() + set.seed(1) + prep_obs <- as_latent_individual(sim_obs) + fit_laplace <- epidist(data = prep_obs, seed = 1, algorithm = "laplace") + diag <- epidist_diagnostics(fit_laplace) +}) + + +test_that("", { # nolint: line_length_linter. + skip_on_cran() + set.seed(1) + prep_obs <- as_latent_individual(sim_obs) + fit <- epidist(data = prep_obs, seed = 1) + diag <- epidist_diagnostics(fit) +}) \ No newline at end of file