Skip to content

Commit

Permalink
Progress on getting diagnostics into a function
Browse files Browse the repository at this point in the history
  • Loading branch information
athowes committed Jul 17, 2024
1 parent feaea31 commit bb61daa
Show file tree
Hide file tree
Showing 24 changed files with 138 additions and 83 deletions.
5 changes: 4 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

S3method(as_latent_individual,data.frame)
S3method(epidist,default)
S3method(epidist_diagnostics,default)
S3method(epidist_diagnostics,epidist_fit)
S3method(epidist_family,default)
S3method(epidist_family,epidist_latent_individual)
S3method(epidist_formula,default)
Expand All @@ -21,6 +23,7 @@ export(correct_primary_censoring_bias)
export(draws_to_long)
export(drop_zero)
export(epidist)
export(epidist_diagnostics)
export(epidist_family)
export(epidist_formula)
export(epidist_prior)
Expand All @@ -45,7 +48,6 @@ export(plot_mean_posterior_pred)
export(plot_recovery)
export(plot_relative_recovery)
export(reverse_obs_at)
export(sample_model)
export(simulate_double_censored_pmf)
export(simulate_exponential_cases)
export(simulate_gillespie)
Expand All @@ -63,6 +65,7 @@ importFrom(checkmate,assert_int)
importFrom(checkmate,assert_names)
importFrom(checkmate,assert_numeric)
importFrom(posterior,as_draws_df)
importFrom(rstan,lookup)
importFrom(stats,as.formula)
importFrom(stats,ecdf)
importFrom(stats,integrate)
Expand Down
65 changes: 34 additions & 31 deletions R/diagnostics.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@
#' @param fit ...
#' @export
epidist_diagnostics <- function(fit) {
UseMethod("epidist")
UseMethod("epidist_diagnostics")
}

#' Default method for returning diagnostics
#'
#' @param fit ...
#' @family defaults
#' @method epidist_diagnostics default
#' @export
epidist_diagnostics.default <- function(fit) {
stop(
Expand All @@ -23,36 +24,38 @@ epidist_diagnostics.default <- function(fit) {
#'
#' @param fit ...
#' @family defaults
#' @method epidist_diagnostics epidist_fit
#' @export
epidist_diagnostics.default <- function(fit) {
stop(
"No epidist_diagnostics method implemented for the class ", class(data), "\n",
"See methods(epidist_diagnostics) for available methods"
)
}

#' Default method for returning diagnostics
#'
#' @param fit ...
#' @family defaults
#' @export
epidist_diagnostics <- function(fit) {
# diag <- fit$sampler_diagnostics(format = "df")
# diagnostics <- data.table(
# samples = nrow(diag),
# max_rhat = round(max(
# fit$summary(
# variables = NULL, posterior::rhat,
# .args = list(na.rm = TRUE)
# )$`posterior::rhat`,
# na.rm = TRUE
# ), 2),
# divergent_transitions = sum(diag$divergent__),
# per_divergent_transitions = sum(diag$divergent__) / nrow(diag),
# max_treedepth = max(diag$treedepth__)
# )
# diagnostics[, no_at_max_treedepth := sum(diag$treedepth__ == max_treedepth)]
# diagnostics[, per_at_max_treedepth := no_at_max_treedepth / nrow(diag)]
epidist_diagnostics.epidist_fit <- function(fit) {
if (fit$algorithm %in% c("laplace", "meanfield", "fullrank", "pathfinder")) {
cli::cli_abort(c(
"!" = paste0(
"Diagnostics not yet supported for the algorithm: ", fit$algorithm
)
))
} else if (!fit$algorithm == "sampling") {
cli::cli_abort(c(

Check warning on line 37 in R/diagnostics.R

View workflow job for this annotation

GitHub Actions / lint-changed-files

file=R/diagnostics.R,line=37,col=22,[trailing_whitespace_linter] Trailing whitespace is superfluous.
"!" = paste0(
"Unrecognised algorithm: ", fit$algorithm
)
))
} else if (fit$algorithm == "sampling") {
np <- brms::nuts_params(fit)
divergent_indices <- np$Parameter == "divergent__"
treedepth_indices <- np$Parameter == "treedepth__"
diagnostics <- data.table(
"samples" = nrow(np) / length(unique(np$Parameter)),
"max_rhat" = round(max(brms::rhat(fit)), 3),
"divergent_transitions" = sum(np[divergent_indices, ]$Value),
"per_divergent_transitions" = mean(np[divergent_indices, ]$Value),
"max_treedepth" = max(np[treedepth_indices, ]$Value)
)
diagnostics[, no_at_max_treedepth :=

Check warning on line 53 in R/diagnostics.R

View workflow job for this annotation

GitHub Actions / lint-changed-files

file=R/diagnostics.R,line=53,col=19,[object_usage_linter] no visible binding for global variable 'no_at_max_treedepth'
sum(np[treedepth_indices, ]$Value == max_treedepth)]

Check warning on line 54 in R/diagnostics.R

View workflow job for this annotation

GitHub Actions / lint-changed-files

file=R/diagnostics.R,line=54,col=20,[indentation_linter] Indentation should be 18 spaces but is 20 spaces.

Check warning on line 54 in R/diagnostics.R

View workflow job for this annotation

GitHub Actions / lint-changed-files

file=R/diagnostics.R,line=54,col=58,[object_usage_linter] no visible binding for global variable 'max_treedepth'
diagnostics[, per_at_max_treedepth := no_at_max_treedepth / samples]

Check warning on line 55 in R/diagnostics.R

View workflow job for this annotation

GitHub Actions / lint-changed-files

file=R/diagnostics.R,line=55,col=19,[object_usage_linter] no visible binding for global variable 'per_at_max_treedepth'

Check warning on line 55 in R/diagnostics.R

View workflow job for this annotation

GitHub Actions / lint-changed-files

file=R/diagnostics.R,line=55,col=43,[object_usage_linter] no visible binding for global variable 'no_at_max_treedepth'

Check warning on line 55 in R/diagnostics.R

View workflow job for this annotation

GitHub Actions / lint-changed-files

file=R/diagnostics.R,line=55,col=65,[object_usage_linter] no visible binding for global variable 'samples'
}

Check warning on line 57 in R/diagnostics.R

View workflow job for this annotation

GitHub Actions / lint-changed-files

file=R/diagnostics.R,line=57,col=1,[trailing_whitespace_linter] Trailing whitespace is superfluous.
# rstan::get_elapsed_time(fit$fit)

Check warning on line 58 in R/diagnostics.R

View workflow job for this annotation

GitHub Actions / lint-changed-files

file=R/diagnostics.R,line=58,col=5,[commented_code_linter] Commented code should be removed.

# timing <- round(fit$time()$total, 1)
return(diagnostics)
}
5 changes: 0 additions & 5 deletions R/globals.R
Original file line number Diff line number Diff line change
@@ -1,11 +1,6 @@
# Generated by roxyglobals: do not edit by hand

utils::globalVariables(c(
"error", # <sample_model>
"no_at_max_treedepth", # <sample_model>
"max_treedepth", # <sample_model>
"per_at_max_treedepth", # <sample_model>
"run_time", # <sample_model>
"meanlog", # <add_natural_scale_mean_sd>
"sdlog", # <add_natural_scale_mean_sd>
"sd", # <add_natural_scale_mean_sd>
Expand Down
1 change: 0 additions & 1 deletion man/add_natural_scale_mean_sd.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion man/correct_primary_censoring_bias.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion man/draws_to_long.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/epidist.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 23 additions & 0 deletions man/epidist_diagnostics.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

24 changes: 24 additions & 0 deletions man/epidist_diagnostics.default.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

24 changes: 24 additions & 0 deletions man/epidist_diagnostics.epidist_fit.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/epidist_family.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions man/epidist_family.default.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/epidist_formula.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions man/epidist_formula.default.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/epidist_prior.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions man/epidist_prior.default.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions man/epidist_stancode.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions man/epidist_stancode.default.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion man/extract_lognormal_draws.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion man/make_relative_to_truth.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

39 changes: 0 additions & 39 deletions man/sample_model.Rd

This file was deleted.

1 change: 0 additions & 1 deletion man/summarise_draws.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion man/summarise_variable.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 16 additions & 0 deletions tests/testthat/test-unit-diagnostics.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
test_that("", { # nolint: line_length_linter.
skip_on_cran()
set.seed(1)
prep_obs <- as_latent_individual(sim_obs)
fit_laplace <- epidist(data = prep_obs, seed = 1, algorithm = "laplace")
diag <- epidist_diagnostics(fit_laplace)
})


test_that("", { # nolint: line_length_linter.
skip_on_cran()
set.seed(1)
prep_obs <- as_latent_individual(sim_obs)
fit <- epidist(data = prep_obs, seed = 1)
diag <- epidist_diagnostics(fit)
})

0 comments on commit bb61daa

Please sign in to comment.