Skip to content

Commit

Permalink
Issue #246: dplyr over data.table (#329)
Browse files Browse the repository at this point in the history
* Remove data.table from simulate.R

* Remove data.table from postprocess.R

* Remove data.table from observe.R aside from filter_obs_by_ptime

* Remove data.table from roxygen in latent_individual.R

* Rewrite as_latent_individual to use dplyr

* Rebase

* Use dplyr to subsample

* Using dplyr in epidist vignette

* Need to round here!

* Use dplyr in epidist_diagnostics

* Use dplyr in filter_obs_by_ptime

* Altering creation of latent_individual to correctly work with dplyr (and uncovering bugs here that existed before)

* Working through getting tests to pass

* Two fixes to epidist vignette

* Update FAQ vignette away from dt

* Update approximate inference vignette

* Remove call to arrange which creates bug (add issue for this)

* Remove data.table from Imports

* Rebase

* Rebase

* Lint

* Fix to logo and improve a little

* data.frame not data.table

* Hexsticker fixes

* Don't library data.table

* Removing final uses of data.table using find in files...

* Import runif

* Remove another mention of data.table

* Use dplyr in cmdstan check

* Need to use , with data.frame

* Fix to index being a factor issue

* Perhaps this is needed?

* Typo

* Regenerate globals and namespace

* Remove excess imports in line with R packages (2e) recommendations

* Need higher version of R for data importing

* Missed some stats:: qualifiers

* Remove old test code

* Some things break when I don't import all of brms (because I'm not importing Stan). I think the saved data

* Run document

* Skip tests with fit on CRAN

* Revert import all of brms

* Resize logo
  • Loading branch information
athowes authored Sep 18, 2024
1 parent 214d8ae commit 4a7066e
Show file tree
Hide file tree
Showing 36 changed files with 240 additions and 326 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/check-cmdstan.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ jobs:

- name: Compile model and check syntax
run: |
dummy_obs <- data.table::data.table(case = 1L, ptime = 1, stime = 2,
dummy_obs <- dplyr::tibble(case = 1L, ptime = 1, stime = 2,
delay_daily = 1, delay_lwr = 1, delay_upr = 2, ptime_lwr = 1,
ptime_upr = 2, stime_lwr = 1, stime_upr = 2, obs_at = 100,
censored = "interval", censored_obs_time = 10, ptime_daily = 1,
Expand Down
4 changes: 1 addition & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,10 @@ URL: https://epidist.epinowcast.org/,
https://github.com/epinowcast/epidist/
BugReports: https://github.com/epinowcast/epidist/issues/
Depends:
R (>= 2.10)
R (>= 3.5.0)
Imports:
brms,
cmdstanr,
data.table,
ggplot2,
purrr,
stats,
Expand All @@ -54,7 +53,6 @@ Suggests:
patchwork
Remotes:
stan-dev/cmdstanr,
Rdatatable/data.table,
paul-buerkner/brms
Config/Needs/website:
r-lib/pkgdown,
Expand Down
25 changes: 3 additions & 22 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -37,29 +37,10 @@ export(simulate_exponential_cases)
export(simulate_gillespie)
export(simulate_secondary)
export(simulate_uniform_cases)
import(brms)
import(cmdstanr)
import(data.table)
import(ggplot2)
importFrom(brms,brmsterms)
importFrom(checkmate,assert_data_frame)
importFrom(checkmate,assert_int)
importFrom(checkmate,assert_names)
importFrom(checkmate,assert_numeric)
importFrom(cli,cli_abort)
importFrom(dplyr,all_of)
importFrom(brms,bf)
importFrom(brms,prior)
importFrom(dplyr,filter)
importFrom(dplyr,full_join)
importFrom(dplyr,mutate)
importFrom(dplyr,select)
importFrom(purrr,map_vec)
importFrom(rstan,lookup)
importFrom(stats,dgamma)
importFrom(stats,dlnorm)
importFrom(stats,pgamma)
importFrom(stats,plnorm)
importFrom(stats,rexp)
importFrom(stats,rgamma)
importFrom(stats,rlnorm)
importFrom(stats,runif)
importFrom(stats,update)
importFrom(utils,capture.output)
4 changes: 0 additions & 4 deletions R/defaults.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
#' @inheritParams epidist_validate
#' @param ... Additional arguments passed to method.
#' @family defaults
#' @importFrom cli cli_abort
#' @export
epidist_validate.default <- function(data, ...) {
cli::cli_abort(
Expand All @@ -17,7 +16,6 @@ epidist_validate.default <- function(data, ...) {
#' @inheritParams epidist_formula
#' @param ... Additional arguments passed to method.
#' @family defaults
#' @importFrom cli cli_abort
#' @export
epidist_formula.default <- function(data, ...) {
cli::cli_abort(
Expand All @@ -31,7 +29,6 @@ epidist_formula.default <- function(data, ...) {
#' @inheritParams epidist_family
#' @param ... Additional arguments passed to method.
#' @family defaults
#' @importFrom cli cli_abort
#' @export
epidist_family.default <- function(data, ...) {
cli::cli_abort(
Expand All @@ -45,7 +42,6 @@ epidist_family.default <- function(data, ...) {
#' @inheritParams epidist_stancode
#' @param ... Additional arguments passed to method.
#' @family defaults
#' @importFrom cli cli_abort
#' @export
epidist_stancode.default <- function(data, ...) {
cli::cli_abort(
Expand Down
29 changes: 15 additions & 14 deletions R/diagnostics.R
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
#'
#' This function computes diagnostics to assess the quality of a fitted model.
#' When the fitting algorithm used is `"sampling"` (HMC) then the output of
#' `epidist_diagnostics` is a `data.table` containing:
#' `epidist_diagnostics` is a `data.frame` containing:
#' * `time`: the total time taken to fit all chains
#' * `samples`: the total number of samples across all chains
#' * `max_rhat`: the highest value of the Gelman-Rubin statistic
Expand Down Expand Up @@ -35,19 +35,20 @@ epidist_diagnostics <- function(fit) {
}
if (fit$algorithm == "sampling") {
np <- brms::nuts_params(fit)
divergent_indices <- np$Parameter == "divergent__"
treedepth_indices <- np$Parameter == "treedepth__"
diagnostics <- data.table(
"time" = sum(rstan::get_elapsed_time(fit$fit)),
"samples" = nrow(np) / length(unique(np$Parameter)),
"max_rhat" = round(max(brms::rhat(fit), na.rm = TRUE), 3),
"divergent_transitions" = sum(np[divergent_indices, ]$Value),
"per_divergent_transitions" = mean(np[divergent_indices, ]$Value),
"max_treedepth" = max(np[treedepth_indices, ]$Value)
)
diagnostics[, no_at_max_treedepth :=
sum(np[treedepth_indices, ]$Value == max_treedepth)]
diagnostics[, per_at_max_treedepth := no_at_max_treedepth / samples]
divergent_ind <- np$Parameter == "divergent__"
treedepth_ind <- np$Parameter == "treedepth__"
diagnostics <- dplyr::tibble(
time = sum(rstan::get_elapsed_time(fit$fit)),
samples = nrow(np) / length(unique(np$Parameter)),
max_rhat = round(max(brms::rhat(fit), na.rm = TRUE), 3),
divergent_transitions = sum(np[divergent_ind, ]$Value),
per_divergent_transitions = mean(np[divergent_ind, ]$Value),
max_treedepth = max(np[treedepth_ind, ]$Value)
) |>
mutate(
no_at_max_treedepth = sum(np[treedepth_ind, ]$Value == max_treedepth),
per_at_max_treedepth = no_at_max_treedepth / samples
)
} else {
cli::cli_abort(c(
"!" = paste0("Unrecognised algorithm: ", fit$algorithm)
Expand Down
4 changes: 2 additions & 2 deletions R/epidist-package.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@

#' @import ggplot2
#' @import cmdstanr
#' @import brms

## usethis namespace: start
#' @import data.table
#' @importFrom dplyr filter select
#' @importFrom brms bf prior
## usethis namespace: end
NULL
34 changes: 7 additions & 27 deletions R/globals.R
Original file line number Diff line number Diff line change
@@ -1,57 +1,37 @@
# Generated by roxyglobals: do not edit by hand

utils::globalVariables(c(
"no_at_max_treedepth", # <epidist_diagnostics>
"max_treedepth", # <epidist_diagnostics>
"per_at_max_treedepth", # <epidist_diagnostics>
"no_at_max_treedepth", # <epidist_diagnostics>
"samples", # <epidist_diagnostics>
"id", # <as_latent_individual.data.frame>
"obs_t", # <as_latent_individual.data.frame>
"obs_at", # <as_latent_individual.data.frame>
"ptime_lwr", # <as_latent_individual.data.frame>
"pwindow", # <as_latent_individual.data.frame>
"stime_lwr", # <as_latent_individual.data.frame>
"ptime_upr", # <as_latent_individual.data.frame>
"stime_upr", # <as_latent_individual.data.frame>
"woverlap", # <as_latent_individual.data.frame>
"swindow", # <as_latent_individual.data.frame>
"delay", # <as_latent_individual.data.frame>
"row_id", # <as_latent_individual.data.frame>
"woverlap", # <epidist_stancode.epidist_latent_individual>
"row_id", # <epidist_stancode.epidist_latent_individual>
"ptime_daily", # <observe_process>
"ptime", # <observe_process>
"ptime_lwr", # <observe_process>
"ptime_upr", # <observe_process>
"stime_daily", # <observe_process>
"ptime_daily", # <observe_process>
"stime", # <observe_process>
"stime_lwr", # <observe_process>
"stime_upr", # <observe_process>
"stime_daily", # <observe_process>
"delay_daily", # <observe_process>
"delay_lwr", # <observe_process>
"delay_upr", # <observe_process>
"obs_at", # <observe_process>
"obs_at", # <filter_obs_by_obs_time>
"ptime", # <filter_obs_by_obs_time>
"censored_obs_time", # <filter_obs_by_obs_time>
"obs_at", # <filter_obs_by_obs_time>
"ptime_lwr", # <filter_obs_by_obs_time>
"censored", # <filter_obs_by_obs_time>
"stime_upr", # <filter_obs_by_obs_time>
"censored", # <filter_obs_by_ptime>
"ptime_upr", # <filter_obs_by_ptime>
"stime_upr", # <filter_obs_by_ptime>
":=", # <filter_obs_by_ptime>
"ptime", # <filter_obs_by_ptime>
"censored_obs_time", # <filter_obs_by_ptime>
"ptime_lwr", # <filter_obs_by_ptime>
"mu", # <add_mean_sd.lognormal_samples>
"sigma", # <add_mean_sd.lognormal_samples>
"sd", # <add_mean_sd.lognormal_samples>
"mu", # <add_mean_sd.gamma_samples>
"sd", # <add_mean_sd.gamma_samples>
"shape", # <add_mean_sd.gamma_samples>
"delay", # <simulate_secondary>
"stime", # <simulate_secondary>
"rlnorm", # <simulate_secondary>
"ptime", # <simulate_secondary>
"delay", # <simulate_secondary>
"prior_old", # <.replace_prior>
"prior_new", # <.replace_prior>
"source_new", # <.replace_prior>
Expand Down
16 changes: 8 additions & 8 deletions R/latent_gamma.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
#' @param prep The result of a call to [brms::posterior_predict()]
#' @param ... Additional arguments
#' @autoglobal
#' @importFrom stats rgamma
#' @keywords internal
posterior_predict_latent_gamma <- function(i, prep, ...) { # nolint: object_length_linter
mu <- brms::get_dpar(prep, "mu", i = i)
Expand All @@ -20,8 +19,8 @@ posterior_predict_latent_gamma <- function(i, prep, ...) { # nolint: object_leng
d_censored <- obs_t + 1
# while loop to impose the truncation
while (d_censored > obs_t) {
p_latent <- runif(1, 0, 1) * pwindow
d_latent <- rgamma(1, shape = shape[s], scale = mu[s] / shape[s])
p_latent <- stats::runif(1, 0, 1) * pwindow
d_latent <- stats::rgamma(1, shape = shape[s], scale = mu[s] / shape[s])
s_latent <- p_latent + d_latent
p_censored <- .floor_mult(p_latent, pwindow)
s_censored <- .floor_mult(s_latent, swindow)
Expand Down Expand Up @@ -53,7 +52,6 @@ posterior_epred_latent_gamma <- function(prep) { # nolint: object_length_linter
#' @param i The index of the observation to calculate the log likelihood of
#' @param prep The result of a call to [brms::prepare_predictions()]
#' @autoglobal
#' @importFrom stats dgamma pgamma
#' @keywords internal
log_lik_latent_gamma <- function(i, prep) {
mu <- brms::get_dpar(prep, "mu", i = i)
Expand All @@ -63,8 +61,8 @@ log_lik_latent_gamma <- function(i, prep) {
pwindow <- prep$data$vreal2[i]
swindow <- prep$data$vreal3[i]

swindow_raw <- runif(prep$ndraws)
pwindow_raw <- runif(prep$ndraws)
swindow_raw <- stats::runif(prep$ndraws)
pwindow_raw <- stats::runif(prep$ndraws)

swindow <- swindow_raw * swindow

Expand All @@ -77,7 +75,9 @@ log_lik_latent_gamma <- function(i, prep) {

d <- y - pwindow + swindow
obs_time <- obs_t - pwindow
lpdf <- dgamma(d, shape = shape, scale = mu / shape, log = TRUE)
lcdf <- pgamma(obs_time, shape = shape, scale = mu / shape, log.p = TRUE)
lpdf <- stats::dgamma(d, shape = shape, scale = mu / shape, log = TRUE)
lcdf <- stats::pgamma(
obs_time, shape = shape, scale = mu / shape, log.p = TRUE
)
return(lpdf - lcdf)
}
Loading

0 comments on commit 4a7066e

Please sign in to comment.