diff --git a/.github/workflows/check-cmdstan.yaml b/.github/workflows/check-cmdstan.yaml index 57a61532c..cd6f35e43 100644 --- a/.github/workflows/check-cmdstan.yaml +++ b/.github/workflows/check-cmdstan.yaml @@ -59,7 +59,7 @@ jobs: - name: Compile model and check syntax run: | - dummy_obs <- data.table::data.table(case = 1L, ptime = 1, stime = 2, + dummy_obs <- dplyr::tibble(case = 1L, ptime = 1, stime = 2, delay_daily = 1, delay_lwr = 1, delay_upr = 2, ptime_lwr = 1, ptime_upr = 2, stime_lwr = 1, stime_upr = 2, obs_at = 100, censored = "interval", censored_obs_time = 10, ptime_daily = 1, diff --git a/DESCRIPTION b/DESCRIPTION index ddcf85fd5..5dd3aebcf 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -23,11 +23,10 @@ URL: https://epidist.epinowcast.org/, https://github.com/epinowcast/epidist/ BugReports: https://github.com/epinowcast/epidist/issues/ Depends: - R (>= 2.10) + R (>= 3.5.0) Imports: brms, cmdstanr, - data.table, ggplot2, purrr, stats, @@ -54,7 +53,6 @@ Suggests: patchwork Remotes: stan-dev/cmdstanr, - Rdatatable/data.table, paul-buerkner/brms Config/Needs/website: r-lib/pkgdown, diff --git a/NAMESPACE b/NAMESPACE index 008ba33c8..60cd41c92 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -37,29 +37,10 @@ export(simulate_exponential_cases) export(simulate_gillespie) export(simulate_secondary) export(simulate_uniform_cases) -import(brms) import(cmdstanr) -import(data.table) import(ggplot2) -importFrom(brms,brmsterms) -importFrom(checkmate,assert_data_frame) -importFrom(checkmate,assert_int) -importFrom(checkmate,assert_names) -importFrom(checkmate,assert_numeric) -importFrom(cli,cli_abort) -importFrom(dplyr,all_of) +importFrom(brms,bf) +importFrom(brms,prior) importFrom(dplyr,filter) -importFrom(dplyr,full_join) +importFrom(dplyr,mutate) importFrom(dplyr,select) -importFrom(purrr,map_vec) -importFrom(rstan,lookup) -importFrom(stats,dgamma) -importFrom(stats,dlnorm) -importFrom(stats,pgamma) -importFrom(stats,plnorm) -importFrom(stats,rexp) -importFrom(stats,rgamma) -importFrom(stats,rlnorm) -importFrom(stats,runif) -importFrom(stats,update) -importFrom(utils,capture.output) diff --git a/R/defaults.R b/R/defaults.R index 5ceca0384..5037777b5 100644 --- a/R/defaults.R +++ b/R/defaults.R @@ -3,7 +3,6 @@ #' @inheritParams epidist_validate #' @param ... Additional arguments passed to method. #' @family defaults -#' @importFrom cli cli_abort #' @export epidist_validate.default <- function(data, ...) { cli::cli_abort( @@ -17,7 +16,6 @@ epidist_validate.default <- function(data, ...) { #' @inheritParams epidist_formula #' @param ... Additional arguments passed to method. #' @family defaults -#' @importFrom cli cli_abort #' @export epidist_formula.default <- function(data, ...) { cli::cli_abort( @@ -31,7 +29,6 @@ epidist_formula.default <- function(data, ...) { #' @inheritParams epidist_family #' @param ... Additional arguments passed to method. #' @family defaults -#' @importFrom cli cli_abort #' @export epidist_family.default <- function(data, ...) { cli::cli_abort( @@ -45,7 +42,6 @@ epidist_family.default <- function(data, ...) { #' @inheritParams epidist_stancode #' @param ... Additional arguments passed to method. #' @family defaults -#' @importFrom cli cli_abort #' @export epidist_stancode.default <- function(data, ...) { cli::cli_abort( diff --git a/R/diagnostics.R b/R/diagnostics.R index 8787e80fd..987a4a976 100644 --- a/R/diagnostics.R +++ b/R/diagnostics.R @@ -2,7 +2,7 @@ #' #' This function computes diagnostics to assess the quality of a fitted model. #' When the fitting algorithm used is `"sampling"` (HMC) then the output of -#' `epidist_diagnostics` is a `data.table` containing: +#' `epidist_diagnostics` is a `data.frame` containing: #' * `time`: the total time taken to fit all chains #' * `samples`: the total number of samples across all chains #' * `max_rhat`: the highest value of the Gelman-Rubin statistic @@ -35,19 +35,20 @@ epidist_diagnostics <- function(fit) { } if (fit$algorithm == "sampling") { np <- brms::nuts_params(fit) - divergent_indices <- np$Parameter == "divergent__" - treedepth_indices <- np$Parameter == "treedepth__" - diagnostics <- data.table( - "time" = sum(rstan::get_elapsed_time(fit$fit)), - "samples" = nrow(np) / length(unique(np$Parameter)), - "max_rhat" = round(max(brms::rhat(fit), na.rm = TRUE), 3), - "divergent_transitions" = sum(np[divergent_indices, ]$Value), - "per_divergent_transitions" = mean(np[divergent_indices, ]$Value), - "max_treedepth" = max(np[treedepth_indices, ]$Value) - ) - diagnostics[, no_at_max_treedepth := - sum(np[treedepth_indices, ]$Value == max_treedepth)] - diagnostics[, per_at_max_treedepth := no_at_max_treedepth / samples] + divergent_ind <- np$Parameter == "divergent__" + treedepth_ind <- np$Parameter == "treedepth__" + diagnostics <- dplyr::tibble( + time = sum(rstan::get_elapsed_time(fit$fit)), + samples = nrow(np) / length(unique(np$Parameter)), + max_rhat = round(max(brms::rhat(fit), na.rm = TRUE), 3), + divergent_transitions = sum(np[divergent_ind, ]$Value), + per_divergent_transitions = mean(np[divergent_ind, ]$Value), + max_treedepth = max(np[treedepth_ind, ]$Value) + ) |> + mutate( + no_at_max_treedepth = sum(np[treedepth_ind, ]$Value == max_treedepth), + per_at_max_treedepth = no_at_max_treedepth / samples + ) } else { cli::cli_abort(c( "!" = paste0("Unrecognised algorithm: ", fit$algorithm) diff --git a/R/epidist-package.R b/R/epidist-package.R index 6e903f271..4b493960e 100644 --- a/R/epidist-package.R +++ b/R/epidist-package.R @@ -3,9 +3,9 @@ #' @import ggplot2 #' @import cmdstanr -#' @import brms ## usethis namespace: start -#' @import data.table +#' @importFrom dplyr filter select +#' @importFrom brms bf prior ## usethis namespace: end NULL diff --git a/R/globals.R b/R/globals.R index d7a6f93b9..6dfc6b4a0 100644 --- a/R/globals.R +++ b/R/globals.R @@ -1,57 +1,37 @@ # Generated by roxyglobals: do not edit by hand utils::globalVariables(c( - "no_at_max_treedepth", # "max_treedepth", # - "per_at_max_treedepth", # + "no_at_max_treedepth", # "samples", # - "id", # - "obs_t", # "obs_at", # "ptime_lwr", # - "pwindow", # "stime_lwr", # "ptime_upr", # "stime_upr", # - "woverlap", # - "swindow", # - "delay", # "row_id", # "woverlap", # - "row_id", # - "ptime_daily", # "ptime", # - "ptime_lwr", # - "ptime_upr", # - "stime_daily", # + "ptime_daily", # "stime", # - "stime_lwr", # - "stime_upr", # + "stime_daily", # "delay_daily", # - "delay_lwr", # - "delay_upr", # - "obs_at", # - "obs_at", # "ptime", # - "censored_obs_time", # + "obs_at", # "ptime_lwr", # - "censored", # "stime_upr", # - "censored", # "ptime_upr", # "stime_upr", # + ":=", # "ptime", # - "censored_obs_time", # "ptime_lwr", # "mu", # "sigma", # - "sd", # "mu", # - "sd", # "shape", # - "delay", # - "stime", # + "rlnorm", # "ptime", # + "delay", # "prior_old", # <.replace_prior> "prior_new", # <.replace_prior> "source_new", # <.replace_prior> diff --git a/R/latent_gamma.R b/R/latent_gamma.R index 95ddced02..ec3d2aab0 100644 --- a/R/latent_gamma.R +++ b/R/latent_gamma.R @@ -6,7 +6,6 @@ #' @param prep The result of a call to [brms::posterior_predict()] #' @param ... Additional arguments #' @autoglobal -#' @importFrom stats rgamma #' @keywords internal posterior_predict_latent_gamma <- function(i, prep, ...) { # nolint: object_length_linter mu <- brms::get_dpar(prep, "mu", i = i) @@ -20,8 +19,8 @@ posterior_predict_latent_gamma <- function(i, prep, ...) { # nolint: object_leng d_censored <- obs_t + 1 # while loop to impose the truncation while (d_censored > obs_t) { - p_latent <- runif(1, 0, 1) * pwindow - d_latent <- rgamma(1, shape = shape[s], scale = mu[s] / shape[s]) + p_latent <- stats::runif(1, 0, 1) * pwindow + d_latent <- stats::rgamma(1, shape = shape[s], scale = mu[s] / shape[s]) s_latent <- p_latent + d_latent p_censored <- .floor_mult(p_latent, pwindow) s_censored <- .floor_mult(s_latent, swindow) @@ -53,7 +52,6 @@ posterior_epred_latent_gamma <- function(prep) { # nolint: object_length_linter #' @param i The index of the observation to calculate the log likelihood of #' @param prep The result of a call to [brms::prepare_predictions()] #' @autoglobal -#' @importFrom stats dgamma pgamma #' @keywords internal log_lik_latent_gamma <- function(i, prep) { mu <- brms::get_dpar(prep, "mu", i = i) @@ -63,8 +61,8 @@ log_lik_latent_gamma <- function(i, prep) { pwindow <- prep$data$vreal2[i] swindow <- prep$data$vreal3[i] - swindow_raw <- runif(prep$ndraws) - pwindow_raw <- runif(prep$ndraws) + swindow_raw <- stats::runif(prep$ndraws) + pwindow_raw <- stats::runif(prep$ndraws) swindow <- swindow_raw * swindow @@ -77,7 +75,9 @@ log_lik_latent_gamma <- function(i, prep) { d <- y - pwindow + swindow obs_time <- obs_t - pwindow - lpdf <- dgamma(d, shape = shape, scale = mu / shape, log = TRUE) - lcdf <- pgamma(obs_time, shape = shape, scale = mu / shape, log.p = TRUE) + lpdf <- stats::dgamma(d, shape = shape, scale = mu / shape, log = TRUE) + lcdf <- stats::pgamma( + obs_time, shape = shape, scale = mu / shape, log.p = TRUE + ) return(lpdf - lcdf) } diff --git a/R/latent_individual.R b/R/latent_individual.R index 2947810dd..c69ae6609 100644 --- a/R/latent_individual.R +++ b/R/latent_individual.R @@ -35,31 +35,30 @@ assert_latent_individual_input <- function(data) { #' object, which may be passed to [epidist()] to perform inference for the #' model. #' -#' @param data A `data.frame` or `data.table` containing line list data +#' @param data A `data.frame` containing line list data #' @rdname as_latent_individual #' @method as_latent_individual data.frame #' @family latent_individual -#' @importFrom checkmate assert_data_frame assert_names assert_int -#' assert_numeric #' @autoglobal #' @export as_latent_individual.data.frame <- function(data) { assert_latent_individual_input(data) - class(data) <- c(class(data), "epidist_latent_individual") - data <- data.table::as.data.table(data) - data[, id := seq_len(.N)] - data[, obs_t := obs_at - ptime_lwr] - data[, pwindow := ifelse( - stime_lwr < ptime_upr, ## if overlap - stime_upr - ptime_lwr, - ptime_upr - ptime_lwr - )] - data[, woverlap := as.numeric(stime_lwr < ptime_upr)] - data[, swindow := stime_upr - stime_lwr] - data[, delay := stime_lwr - ptime_lwr] - data[, row_id := seq_len(.N)] + class(data) <- c("epidist_latent_individual", class(data)) + data <- data |> + mutate( + obs_t = obs_at - ptime_lwr, + pwindow = ifelse( + stime_lwr < ptime_upr, + stime_upr - ptime_lwr, + ptime_upr - ptime_lwr + ), + woverlap = as.numeric(stime_lwr < ptime_upr), + swindow = stime_upr - stime_lwr, + delay = stime_lwr - ptime_lwr, + row_id = dplyr::row_number() + ) if (nrow(data) > 1) { - data <- data[, id := as.factor(id)] + data <- mutate(data, row_id = factor(row_id)) } epidist_validate(data) return(data) @@ -72,9 +71,7 @@ as_latent_individual.data.frame <- function(data) { #' `is_latent_individual()` is true, it also checks that `data` is a #' `data.frame` with the correct columns. #' -#' @param data A `data.frame` or `data.table` containing line list data -#' @importFrom checkmate assert_data_frame assert_names assert_int -#' assert_numeric +#' @param data A `data.frame` containing line list data #' @method epidist_validate epidist_latent_individual #' @family latent_individual #' @export @@ -85,23 +82,22 @@ epidist_validate.epidist_latent_individual <- function(data) { names(data), must.include = c("case", "ptime_lwr", "ptime_upr", "stime_lwr", "stime_upr", "obs_at", - "id", "obs_t", "pwindow", "woverlap", + "obs_t", "pwindow", "woverlap", "swindow", "delay", "row_id") ) if (nrow(data) > 1) { - checkmate::assert_factor(data$id) + checkmate::assert_factor(data$row_id) } checkmate::assert_numeric(data$obs_t, lower = 0) checkmate::assert_numeric(data$pwindow, lower = 0) checkmate::assert_numeric(data$woverlap, lower = 0) checkmate::assert_numeric(data$swindow, lower = 0) checkmate::assert_numeric(data$delay, lower = 0) - checkmate::assert_integer(data$row_id, lower = 0) } #' Check if data has the `epidist_latent_individual` class #' -#' @param data A `data.frame` or `data.table` containing line list data +#' @param data A `data.frame` containing line list data #' @family latent_individual #' @export is_latent_individual <- function(data) { @@ -110,11 +106,10 @@ is_latent_individual <- function(data) { #' Check if data has the `epidist_latent_individual` class #' -#' @param data A `data.frame` or `data.table` containing line list data +#' @param data A `data.frame` containing line list data #' @param family Output of a call to `brms::brmsfamily()` #' @param ... ... #' -#' @importFrom rstan lookup #' @method epidist_family epidist_latent_individual #' @family latent_individual #' @export @@ -154,8 +149,6 @@ epidist_family.epidist_latent_individual <- function(data, #' @param ... ... #' @method epidist_formula epidist_latent_individual #' @family latent_individual -#' @importFrom brms brmsterms -#' @importFrom stats update #' @export epidist_formula.epidist_latent_individual <- function(data, family, formula, ...) { @@ -176,7 +169,6 @@ epidist_formula.epidist_latent_individual <- function(data, family, formula, #' @method epidist_stancode epidist_latent_individual #' @family latent_individual #' @autoglobal -#' @importFrom purrr map_vec #' @export epidist_stancode.epidist_latent_individual <- function(data, family = @@ -222,19 +214,19 @@ epidist_stancode.epidist_latent_individual <- function(data, stanvars_data <- brms::stanvar( block = "data", scode = "int wN;", - x = nrow(data[woverlap > 0]), + x = nrow(filter(data, woverlap > 0)), name = "wN" ) + brms::stanvar( block = "data", scode = "array[N - wN] int noverlap;", - x = data[woverlap == 0][, row_id], + x = filter(data, woverlap == 0)$row_id, name = "noverlap" ) + brms::stanvar( block = "data", scode = "array[wN] int woverlap;", - x = data[woverlap > 0][, row_id], + x = filter(data, woverlap > 0)$row_id, name = "woverlap" ) diff --git a/R/latent_lognormal.R b/R/latent_lognormal.R index a8c14a409..b098656cc 100644 --- a/R/latent_lognormal.R +++ b/R/latent_lognormal.R @@ -7,7 +7,6 @@ #' @param prep The result of a call to [brms::posterior_predict()] #' @param ... Additional arguments #' @autoglobal -#' @importFrom stats rlnorm #' @keywords internal posterior_predict_latent_lognormal <- function(i, prep, ...) { # nolint: object_length_linter mu <- brms::get_dpar(prep, "mu", i = i) @@ -21,8 +20,8 @@ posterior_predict_latent_lognormal <- function(i, prep, ...) { # nolint: object_ d_censored <- obs_t + 1 # while loop to impose the truncation while (d_censored > obs_t) { - p_latent <- runif(1, 0, 1) * pwindow - d_latent <- rlnorm(1, meanlog = mu[s], sdlog = sigma[s]) + p_latent <- stats::runif(1, 0, 1) * pwindow + d_latent <- stats::rlnorm(1, meanlog = mu[s], sdlog = sigma[s]) s_latent <- p_latent + d_latent p_censored <- .floor_mult(p_latent, pwindow) s_censored <- .floor_mult(s_latent, swindow) @@ -56,7 +55,6 @@ posterior_epred_latent_lognormal <- function(prep) { # nolint: object_length_lin #' @param i The index of the observation to calculate the log likelihood of #' @param prep The result of a call to [brms::prepare_predictions()] #' @autoglobal -#' @importFrom stats dlnorm plnorm #' @keywords internal log_lik_latent_lognormal <- function(i, prep) { mu <- brms::get_dpar(prep, "mu", i = i) @@ -69,8 +67,8 @@ log_lik_latent_lognormal <- function(i, prep) { # Generates values of the swindow_raw and pwindow_raw, but really these should # be extracted from prep or the fitted raws somehow. See: # https://github.com/epinowcast/epidist/issues/267 - swindow_raw <- runif(prep$ndraws) - pwindow_raw <- runif(prep$ndraws) + swindow_raw <- stats::runif(prep$ndraws) + pwindow_raw <- stats::runif(prep$ndraws) swindow <- swindow_raw * swindow @@ -83,7 +81,7 @@ log_lik_latent_lognormal <- function(i, prep) { d <- y - pwindow + swindow obs_time <- obs_t - pwindow - lpdf <- dlnorm(d, meanlog = mu, sdlog = sigma, log = TRUE) - lcdf <- plnorm(obs_time, meanlog = mu, sdlog = sigma, log.p = TRUE) + lpdf <- stats::dlnorm(d, meanlog = mu, sdlog = sigma, log = TRUE) + lcdf <- stats::plnorm(obs_time, meanlog = mu, sdlog = sigma, log.p = TRUE) return(lpdf - lcdf) } diff --git a/R/observe.R b/R/observe.R index 45fc049ff..7f88e7131 100644 --- a/R/observe.R +++ b/R/observe.R @@ -19,25 +19,19 @@ #' @autoglobal #' @export observe_process <- function(linelist) { - clinelist <- data.table::copy(linelist) - clinelist[, ptime_daily := floor(ptime)] - clinelist[, ptime_lwr := ptime_daily] - clinelist[, ptime_upr := ptime_daily + 1] - # How the second event would be recorded in the data - clinelist[, stime_daily := floor(stime)] - clinelist[, stime_lwr := stime_daily] - clinelist[, stime_upr := stime_daily + 1] - # How would we observe the delay distribution - # previously delay_daily would be the floor(delay) - clinelist[, delay_daily := stime_daily - ptime_daily] - clinelist[, delay_lwr := purrr::map_dbl(delay_daily, ~ max(0, . - 1))] - clinelist[, delay_upr := delay_daily + 1] - # We assume observation time is the ceiling of the maximum delay - clinelist[, obs_at := stime |> - max() |> - ceiling()] - - return(clinelist) + linelist |> + mutate( + ptime_daily = floor(ptime), + ptime_lwr = ptime_daily, + ptime_upr = ptime_daily + 1, + stime_daily = floor(stime), + stime_lwr = stime_daily, + stime_upr = stime_daily + 1, + delay_daily = stime_daily - ptime_daily, + delay_lwr = purrr::map_dbl(delay_daily, ~ max(0, . - 1)), + delay_upr = delay_daily + 1, + obs_at = ceiling(max(stime)) + ) } #' Filter observations based on a observation time of secondary events @@ -48,14 +42,14 @@ observe_process <- function(linelist) { #' @autoglobal #' @export filter_obs_by_obs_time <- function(linelist, obs_time) { - truncated_linelist <- data.table::copy(linelist) - truncated_linelist[, obs_at := obs_time] - truncated_linelist[, obs_time := obs_time - ptime] - truncated_linelist[, censored_obs_time := obs_at - ptime_lwr] - truncated_linelist[, censored := "interval"] - truncated_linelist <- truncated_linelist[stime_upr <= obs_at] - - return(truncated_linelist) + linelist |> + mutate( + obs_at = obs_time, + obs_time = obs_time - ptime, + censored_obs_time = obs_at - ptime_lwr, + censored = "interval" + ) |> + filter(stime_upr <= obs_at) } #' Filter observations based on the observation time of primary events @@ -69,27 +63,26 @@ filter_obs_by_obs_time <- function(linelist, obs_time) { filter_obs_by_ptime <- function(linelist, obs_time, obs_at = c("obs_secondary", "max_secondary")) { obs_at <- match.arg(obs_at) - pfilt_t <- obs_time - truncated_linelist <- data.table::copy(linelist) - truncated_linelist[, censored := "interval"] - truncated_linelist <- truncated_linelist[ptime_upr <= pfilt_t] - + truncated_linelist <- linelist |> + mutate(censored = "interval") |> + filter(ptime_upr <= pfilt_t) if (obs_at == "obs_secondary") { # Update observation time to be the same as the maximum secondary time - truncated_linelist[, obs_at := stime_upr] + truncated_linelist <- mutate(truncated_linelist, obs_at = stime_upr) } else if (obs_at == "max_secondary") { - truncated_linelist[, obs_at := stime_upr |> max() |> ceiling()] + truncated_linelist <- truncated_linelist |> + mutate(obs_at := stime_upr |> max() |> ceiling()) } - - # make observation time as specified - truncated_linelist[, obs_time := obs_at - ptime] - # Assuming truncation at the beginning of the censoring window - truncated_linelist[, censored_obs_time := obs_at - ptime_lwr] - - # set observation time to artifial observation time + # Make observation time as specified + truncated_linelist <- truncated_linelist |> + mutate( + obs_time = obs_at - ptime, + censored_obs_time = obs_at - ptime_lwr + ) + # Set observation time to artificial observation time if needed if (obs_at == "obs_secondary") { - truncated_linelist[, obs_at := pfilt_t] + truncated_linelist <- mutate(truncated_linelist, obs_at = pfilt_t) } return(truncated_linelist) } diff --git a/R/postprocess.R b/R/postprocess.R index 21eb2beaf..e8c506b5d 100644 --- a/R/postprocess.R +++ b/R/postprocess.R @@ -23,11 +23,10 @@ predict_delay_parameters <- function(fit, newdata = NULL, ...) { df[[dpar]] <- as.vector(lp_dpar) } class(df) <- c( - class(df), paste0(sub(".*_", "", fit$family$name), "_samples") + paste0(sub(".*_", "", fit$family$name), "_samples"), class(df) ) - dt <- as.data.table(df) - dt <- add_mean_sd(dt) - return(dt) + df <- add_mean_sd(df) + return(df) } #' @rdname predict_delay_parameters @@ -73,10 +72,10 @@ add_mean_sd.default <- function(data, ...) { #' @autoglobal #' @export add_mean_sd.lognormal_samples <- function(data, ...) { - nat_dt <- data.table::copy(data) - nat_dt <- nat_dt[, mean := exp(mu + sigma ^ 2 / 2)] - nat_dt <- nat_dt[, sd := mean * sqrt(exp(sigma ^ 2) - 1)] - return(nat_dt[]) + mutate(data, + mean = exp(mu + sigma ^ 2 / 2), + sd = mean * sqrt(exp(sigma ^ 2) - 1) + ) } #' Add natural scale mean and standard deviation parameters for a latent gamma @@ -91,8 +90,8 @@ add_mean_sd.lognormal_samples <- function(data, ...) { #' @autoglobal #' @export add_mean_sd.gamma_samples <- function(data, ...) { - nat_dt <- data.table::copy(data) - nat_dt <- nat_dt[, mean := mu] - nat_dt <- nat_dt[, sd := mu / sqrt(shape)] - return(nat_dt[]) + mutate(data, + mean = mu, + sd = mu / sqrt(shape) + ) } diff --git a/R/prior.R b/R/prior.R index 5c5d58f8c..85ee15086 100644 --- a/R/prior.R +++ b/R/prior.R @@ -89,10 +89,10 @@ epidist_family_prior.default <- function(family, formula, ...) { #' @family prior #' @export epidist_family_prior.lognormal <- function(family, formula, ...) { - prior <- brms::prior("normal(1, 1)", class = "Intercept") + prior <- prior("normal(1, 1)", class = "Intercept") if ("sigma" %in% names(formula$pforms)) { # Case with a model on sigma - sigma_prior <- brms::prior( + sigma_prior <- prior( "normal(-0.7, 0.4)", class = "Intercept", dpar = "sigma" ) } else if ("sigma" %in% names(formula$pfix)) { @@ -100,7 +100,7 @@ epidist_family_prior.lognormal <- function(family, formula, ...) { sigma_prior <- NULL } else { # Case with no model on sigma - sigma_prior <- brms::prior( + sigma_prior <- prior( "lognormal(-0.7, 0.4)", class = "sigma", lb = 0, ub = "NA" ) } diff --git a/R/simulate.R b/R/simulate.R index a9c3dc25a..e78c0be53 100644 --- a/R/simulate.R +++ b/R/simulate.R @@ -7,15 +7,14 @@ #' @param t Upper bound of the uniform distribution to generate primary event #' times. #' -#' @return A `data.table` with two columns: `case` (case number) and `ptime` +#' @return A `data.frame` with two columns: `case` (case number) and `ptime` #' (primary event time). #' #' @family simulate -#' @importFrom stats runif #' @export simulate_uniform_cases <- function(sample_size = 1000, t = 60) { - data.table::data.table( - case = 1:sample_size, ptime = runif(sample_size, 0, t) + data.frame( + case = 1:sample_size, ptime = stats::runif(sample_size, 0, t) ) } @@ -31,11 +30,10 @@ simulate_uniform_cases <- function(sample_size = 1000, t = 60) { #' @param seed The random seed to be used in the simulation process. #' @param t Upper bound of the survival time. Defaults to 30. #' -#' @return A `data.table` with two columns: `case` (case number) and `ptime` +#' @return A `data.frame` with two columns: `case` (case number) and `ptime` #' (primary event time). #' #' @family simulate -#' @importFrom stats runif #' @export simulate_exponential_cases <- function(r = 0.2, sample_size = 10000, @@ -44,7 +42,7 @@ simulate_exponential_cases <- function(r = 0.2, if (!missing(seed)) { set.seed(seed) } - quant <- runif(sample_size, 0, 1) + quant <- stats::runif(sample_size, 0, 1) if (r == 0) { ptime <- quant * t @@ -52,10 +50,7 @@ simulate_exponential_cases <- function(r = 0.2, ptime <- log(1 + quant * (exp(r * t) - 1)) / r } - cases <- data.table::data.table( - case = seq_along(ptime), - ptime = ptime - ) + cases <- data.frame(case = seq_along(ptime), ptime = ptime) return(cases) } @@ -72,11 +67,10 @@ simulate_exponential_cases <- function(r = 0.2, #' @param N The total population size. Defaults to 10000. #' @param seed The random seed to be used in the simulation process. #' -#' @return A `data.table` with two columns: `case` (case number) and `ptime` +#' @return A `data.frame` with two columns: `case` (case number) and `ptime` #' (primary event time). #' #' @family simulate -#' @importFrom stats rexp #' @export simulate_gillespie <- function(r = 0.2, gamma = 1 / 7, @@ -97,7 +91,7 @@ simulate_gillespie <- function(r = 0.2, srates <- sum(rates) if (srates > 0) { - deltat <- rexp(1, rate = srates) + deltat <- stats::rexp(1, rate = srates) t <- t + deltat wevent <- sample(seq_along(rates), size = 1, prob = rates) @@ -113,11 +107,7 @@ simulate_gillespie <- function(r = 0.2, } } - cases <- data.table::data.table( - case = seq_along(ptime), - ptime = ptime - ) - + cases <- data.frame(case = seq_along(ptime), ptime = ptime) return(cases) } @@ -131,17 +121,17 @@ simulate_gillespie <- function(r = 0.2, #' @param dist The delay distribution to be used. Defaults to [rlnorm()]. #' @param ... Arguments to be passed to the delay distribution function. #' -#' @return A `data.table` that augments `linelist` with two new columns: `delay` +#' @return A `data.frame` that augments `linelist` with two new columns: `delay` #' (secondary event latency) and `stime` (the time of the secondary event). #' #' @family simulate #' @autoglobal +#' @importFrom dplyr mutate #' @export simulate_secondary <- function(linelist, dist = rlnorm, ...) { - obs <- data.table::copy(linelist) - - obs[, delay := dist(.N, ...)] - obs[, stime := ptime + delay] - - return(obs) + linelist |> + mutate( + delay = dist(dplyr::n(), ...), + stime = ptime + delay + ) } diff --git a/R/utils.R b/R/utils.R index 2c373e3cd..15f92b6ea 100644 --- a/R/utils.R +++ b/R/utils.R @@ -36,7 +36,6 @@ #' #' @param x A number to be rounded down #' @param f A positive number specifying the multiple to be rounded down to -#' @importFrom checkmate assert_numeric #' @keywords internal .floor_mult <- function(x, f = 1) { checkmate::assert_numeric(f, lower = 0) @@ -53,9 +52,6 @@ #' #' @param old_prior One or more prior distributions in the class `brmsprior` #' @param new_prior One or more prior distributions in the class `brmsprior` -#' @importFrom cli cli_abort -#' @importFrom utils capture.output -#' @importFrom dplyr full_join filter select all_of #' @autoglobal #' @keywords internal .replace_prior <- function(old_prior, new_prior) { @@ -70,8 +66,8 @@ if (any(is.na(prior$prior_old))) { missing_prior <- utils::capture.output(print( prior |> - dplyr::filter(is.na(prior_old)) |> - dplyr::select( + filter(is.na(prior_old)) |> + select( prior = prior_new, dplyr::all_of(cols), source = source_new ) )) @@ -83,8 +79,8 @@ } prior <- prior |> - dplyr::filter(!is.na(prior_old), !is.na(prior_new)) |> - dplyr::select(prior = prior_new, dplyr::all_of(cols), source = source_new) + filter(!is.na(prior_old), !is.na(prior_new)) |> + select(prior = prior_new, dplyr::all_of(cols), source = source_new) return(prior) } diff --git a/inst/make_hexsticker.R b/inst/make_hexsticker.R index 24a542c36..745b28c43 100644 --- a/inst/make_hexsticker.R +++ b/inst/make_hexsticker.R @@ -1,6 +1,7 @@ library(hexSticker) library(sysfonts) library(ggplot2) +library(dplyr) # font setup font_add_google("Zilla Slab Highlight", "useme") @@ -8,66 +9,65 @@ font_add_google("Zilla Slab Highlight", "useme") # make standard plot outbreak <- simulate_gillespie(seed = 101) -secondary_dist <- data.table(mu = 1.8, sigma = 0.5) -class(secondary_dist) <- c(class(secondary_dist), "lognormal_samples") +secondary_dist <- data.frame(mu = 1.8, sigma = 0.5) +class(secondary_dist) <- c("lognormal_samples", class(secondary_dist)) secondary_dist <- add_mean_sd(secondary_dist) obs <- outbreak |> simulate_secondary( - meanlog = secondary_dist$meanlog[[1]], - sdlog = secondary_dist$sdlog[[1]] + meanlog = secondary_dist$mu[[1]], + sdlog = secondary_dist$sigma[[1]] ) |> observe_process() truncated_obs <- obs |> - filter_obs_by_obs_time(obs_time = 25) + filter_obs_by_obs_time(obs_time = 25) |> + slice_sample(n = 200, replace = FALSE) -truncated_obs <- truncated_obs[sample(seq_len(.N), 200, replace = FALSE)] +combined_obs <- bind_rows( + truncated_obs, + mutate(obs, obs_at = max(stime_daily)) +) |> + mutate(obs_at = factor(obs_at)) -combined_obs <- combine_obs(truncated_obs, obs) -meanlog <- secondary_dist$meanlog[[1]] -sdlog <- secondary_dist$sdlog[[1]] +meanlog <- secondary_dist$mu[[1]] +sdlog <- secondary_dist$sigma[[1]] -plot <- combined_obs |> +hex_plot <- combined_obs |> ggplot() + - aes(x = delay_daily) + + aes(x = delay_daily, fill = obs_at) + geom_histogram( - aes(y = after_stat(density), fill = obs_at), + aes(y = after_stat(density)), binwidth = 1, position = "dodge" ) + - lims(x = c(0, 18)) - -if (!missing(meanlog) && !missing(sdlog)) { - plot <- plot + - stat_function( - fun = dlnorm, args = c(meanlog, sdlog), n = 100, - col = "#696767b1" - ) -} - -# strip out most of the background -hex_plot <- plot + - scale_fill_brewer(palette = "Blues", direction = 1) + + lims(x = c(0, 18)) + + stat_function( + fun = dlnorm, args = c(meanlog, sdlog), n = 100, + col = "#696767b1" + ) + + scale_fill_brewer(palette = "Set2", direction = 1) + scale_y_continuous(breaks = NULL) + labs(x = "", y = "") + theme_void() + theme_transparent() + - theme(legend.position = "none", - panel.background = element_blank()) + theme( + legend.position = "none", + panel.background = element_blank() + ) -# make and save hexsticker +# Make and save hexsticker sticker( hex_plot, package = "epidist", p_size = 23, p_color = "#646770", - s_x = 1, - s_y = 0.85, - s_width = 1.3, - s_height = 0.75, + p_x = 1.3, + p_y = 1.15, + s_x = 0.85, + s_y = 1, + s_width = 1.2, + s_height = 1.2, h_fill = "#ffffff", h_color = "#646770", - filename = file.path("man", "figures", "logo.png"), - u_color = "#646770", - u_size = 3.5 + filename = file.path("man", "figures", "logo.png") ) diff --git a/inst/vector-real.R b/inst/vector-real.R deleted file mode 100644 index 7aa124792..000000000 --- a/inst/vector-real.R +++ /dev/null @@ -1,10 +0,0 @@ -source("~/Documents/cfa/delays/epidist/tests/testthat/setup.R", echo = TRUE) -prep_obs <- as_latent_individual(sim_obs) -set.seed(1) - -# Fails -fit <- epidist( - data = prep_obs, - formula = brms::bf(mu ~ 1), - seed = 1 -) diff --git a/man/as_latent_individual.Rd b/man/as_latent_individual.Rd index 9e7d35c9f..1c4d097d0 100644 --- a/man/as_latent_individual.Rd +++ b/man/as_latent_individual.Rd @@ -10,7 +10,7 @@ as_latent_individual(data) \method{as_latent_individual}{data.frame}(data) } \arguments{ -\item{data}{A \code{data.frame} or \code{data.table} containing line list data} +\item{data}{A \code{data.frame} containing line list data} } \description{ This function prepares data for use with the latent individual model. It does diff --git a/man/epidist_diagnostics.Rd b/man/epidist_diagnostics.Rd index 6860a82b4..7e9bf9838 100644 --- a/man/epidist_diagnostics.Rd +++ b/man/epidist_diagnostics.Rd @@ -12,7 +12,7 @@ epidist_diagnostics(fit) \description{ This function computes diagnostics to assess the quality of a fitted model. When the fitting algorithm used is \code{"sampling"} (HMC) then the output of -\code{epidist_diagnostics} is a \code{data.table} containing: +\code{epidist_diagnostics} is a \code{data.frame} containing: \itemize{ \item \code{time}: the total time taken to fit all chains \item \code{samples}: the total number of samples across all chains diff --git a/man/epidist_family.epidist_latent_individual.Rd b/man/epidist_family.epidist_latent_individual.Rd index 25e54db58..61b49a7a1 100644 --- a/man/epidist_family.epidist_latent_individual.Rd +++ b/man/epidist_family.epidist_latent_individual.Rd @@ -7,7 +7,7 @@ \method{epidist_family}{epidist_latent_individual}(data, family = "lognormal", ...) } \arguments{ -\item{data}{A \code{data.frame} or \code{data.table} containing line list data} +\item{data}{A \code{data.frame} containing line list data} \item{family}{Output of a call to \code{brms::brmsfamily()}} diff --git a/man/epidist_validate.epidist_latent_individual.Rd b/man/epidist_validate.epidist_latent_individual.Rd index b3c941093..e29b63a74 100644 --- a/man/epidist_validate.epidist_latent_individual.Rd +++ b/man/epidist_validate.epidist_latent_individual.Rd @@ -7,7 +7,7 @@ \method{epidist_validate}{epidist_latent_individual}(data) } \arguments{ -\item{data}{A \code{data.frame} or \code{data.table} containing line list data} +\item{data}{A \code{data.frame} containing line list data} } \description{ This function checks whether the provided \code{data} object is suitable for diff --git a/man/figures/logo.png b/man/figures/logo.png index b2b1e58d9..84a01d17b 100644 Binary files a/man/figures/logo.png and b/man/figures/logo.png differ diff --git a/man/is_latent_individual.Rd b/man/is_latent_individual.Rd index 0d6df1666..0240eadb5 100644 --- a/man/is_latent_individual.Rd +++ b/man/is_latent_individual.Rd @@ -7,7 +7,7 @@ is_latent_individual(data) } \arguments{ -\item{data}{A \code{data.frame} or \code{data.table} containing line list data} +\item{data}{A \code{data.frame} containing line list data} } \description{ Check if data has the \code{epidist_latent_individual} class diff --git a/man/simulate_exponential_cases.Rd b/man/simulate_exponential_cases.Rd index 3a51cd6ac..f5ad55564 100644 --- a/man/simulate_exponential_cases.Rd +++ b/man/simulate_exponential_cases.Rd @@ -16,7 +16,7 @@ simulate_exponential_cases(r = 0.2, sample_size = 10000, seed, t = 30) \item{t}{Upper bound of the survival time. Defaults to 30.} } \value{ -A \code{data.table} with two columns: \code{case} (case number) and \code{ptime} +A \code{data.frame} with two columns: \code{case} (case number) and \code{ptime} (primary event time). } \description{ diff --git a/man/simulate_gillespie.Rd b/man/simulate_gillespie.Rd index c7582c485..b53340d53 100644 --- a/man/simulate_gillespie.Rd +++ b/man/simulate_gillespie.Rd @@ -18,7 +18,7 @@ simulate_gillespie(r = 0.2, gamma = 1/7, I0 = 50, N = 10000, seed) \item{seed}{The random seed to be used in the simulation process.} } \value{ -A \code{data.table} with two columns: \code{case} (case number) and \code{ptime} +A \code{data.frame} with two columns: \code{case} (case number) and \code{ptime} (primary event time). } \description{ diff --git a/man/simulate_secondary.Rd b/man/simulate_secondary.Rd index f066194ff..e0e097972 100644 --- a/man/simulate_secondary.Rd +++ b/man/simulate_secondary.Rd @@ -14,7 +14,7 @@ simulate_secondary(linelist, dist = rlnorm, ...) \item{...}{Arguments to be passed to the delay distribution function.} } \value{ -A \code{data.table} that augments \code{linelist} with two new columns: \code{delay} +A \code{data.frame} that augments \code{linelist} with two new columns: \code{delay} (secondary event latency) and \code{stime} (the time of the secondary event). } \description{ diff --git a/man/simulate_uniform_cases.Rd b/man/simulate_uniform_cases.Rd index ddc6efbfb..8058f0c17 100644 --- a/man/simulate_uniform_cases.Rd +++ b/man/simulate_uniform_cases.Rd @@ -13,7 +13,7 @@ simulate_uniform_cases(sample_size = 1000, t = 60) times.} } \value{ -A \code{data.table} with two columns: \code{case} (case number) and \code{ptime} +A \code{data.frame} with two columns: \code{case} (case number) and \code{ptime} (primary event time). } \description{ diff --git a/tests/testthat/setup.R b/tests/testthat/setup.R index 46d569963..c3d6abed2 100644 --- a/tests/testthat/setup.R +++ b/tests/testthat/setup.R @@ -15,9 +15,8 @@ sim_obs <- simulate_gillespie() |> sdlog = sdlog ) |> observe_process() |> - filter_obs_by_obs_time(obs_time = obs_time) - -sim_obs <- sim_obs[sample(seq_len(.N), sample_size, replace = FALSE)] + filter_obs_by_obs_time(obs_time = obs_time) |> + dplyr::slice_sample(n = sample_size, replace = FALSE) set.seed(101) @@ -34,10 +33,8 @@ sim_obs_gamma <- simulate_gillespie() |> rate = rate ) |> observe_process() |> - filter_obs_by_obs_time(obs_time = obs_time) - -sim_obs_gamma <- - sim_obs_gamma[sample(seq_len(.N), sample_size, replace = FALSE)] + filter_obs_by_obs_time(obs_time = obs_time) |> + dplyr::slice_sample(n = sample_size, replace = FALSE) # Data with a sex difference @@ -65,10 +62,6 @@ sim_obs_sex_f <- dplyr::filter(sim_obs_sex, sex == 1) |> ) sim_obs_sex <- dplyr::bind_rows(sim_obs_sex_m, sim_obs_sex_f) |> - dplyr::arrange(case) - -sim_obs_sex <- sim_obs_sex |> observe_process() |> - filter_obs_by_obs_time(obs_time = obs_time) - -sim_obs_sex <- sim_obs_sex[sample(seq_len(.N), sample_size, replace = FALSE)] + filter_obs_by_obs_time(obs_time = obs_time) |> + dplyr::slice_sample(n = sample_size, replace = FALSE) diff --git a/tests/testthat/test-latent_gamma.R b/tests/testthat/test-latent_gamma.R index 775026e43..d567de880 100644 --- a/tests/testthat/test-latent_gamma.R +++ b/tests/testthat/test-latent_gamma.R @@ -1,4 +1,5 @@ test_that("posterior_predict_latent_gamma outputs positive integers with length equal to draws", { # nolint: line_length_linter. + skip_on_cran() fit_gamma <- readRDS( system.file("extdata/fit_gamma.rds", package = "epidist") ) @@ -11,6 +12,7 @@ test_that("posterior_predict_latent_gamma outputs positive integers with length }) test_that("posterior_predict_latent_gamma errors for i out of bounds", { # nolint: line_length_linter. + skip_on_cran() fit_gamma <- readRDS( system.file("extdata/fit_gamma.rds", package = "epidist") ) @@ -20,6 +22,7 @@ test_that("posterior_predict_latent_gamma errors for i out of bounds", { # nolin }) test_that("posterior_predict_latent_gamma can generate predictions with no censoring", { # nolint: line_length_linter. + skip_on_cran() fit_gamma <- readRDS( system.file("extdata/fit_gamma.rds", package = "epidist") ) @@ -32,6 +35,7 @@ test_that("posterior_predict_latent_gamma can generate predictions with no censo }) test_that("posterior_predict_latent_gamma predicts delays for which the data is in the 95% credible interval", { # nolint: line_length_linter. + skip_on_cran() fit_gamma <- readRDS( system.file("extdata/fit_gamma.rds", package = "epidist") ) @@ -52,6 +56,7 @@ test_that("posterior_predict_latent_gamma predicts delays for which the data is }) test_that("posterior_epred_latent_gamma creates a array of non-negative numbers with the correct dimensions", { # nolint: line_length_linter. + skip_on_cran() fit_gamma <- readRDS( system.file("extdata/fit_gamma.rds", package = "epidist") ) @@ -64,6 +69,7 @@ test_that("posterior_epred_latent_gamma creates a array of non-negative numbers }) test_that("log_lik_latent_gamma produces a vector with length ndraws of finite non-NA numbers", { # nolint: line_length_linter. + skip_on_cran() fit_gamma <- readRDS( system.file("extdata/fit_gamma.rds", package = "epidist") ) diff --git a/tests/testthat/test-latent_individual.R b/tests/testthat/test-latent_individual.R index 5f7a55da6..910f73c80 100644 --- a/tests/testthat/test-latent_individual.R +++ b/tests/testthat/test-latent_individual.R @@ -6,7 +6,6 @@ as_string_formula <- function(formula) { test_that("as_latent_individual.data.frame with default settings an object with the correct classes", { # nolint: line_length_linter. prep_obs <- as_latent_individual(sim_obs) - expect_s3_class(prep_obs, "data.table") expect_s3_class(prep_obs, "data.frame") expect_s3_class(prep_obs, "epidist_latent_individual") }) diff --git a/tests/testthat/test-latent_lognormal.R b/tests/testthat/test-latent_lognormal.R index 1e2efdb45..cc8e963a9 100644 --- a/tests/testthat/test-latent_lognormal.R +++ b/tests/testthat/test-latent_lognormal.R @@ -1,4 +1,5 @@ test_that("posterior_predict_latent_lognormal outputs positive integers with length equal to draws", { # nolint: line_length_linter. + skip_on_cran() fit <- readRDS( system.file("extdata/fit.rds", package = "epidist") ) @@ -11,6 +12,7 @@ test_that("posterior_predict_latent_lognormal outputs positive integers with len }) test_that("posterior_predict_latent_lognormal errors for i out of bounds", { # nolint: line_length_linter. + skip_on_cran() fit <- readRDS( system.file("extdata/fit.rds", package = "epidist") ) @@ -20,6 +22,7 @@ test_that("posterior_predict_latent_lognormal errors for i out of bounds", { # n }) test_that("posterior_predict_latent_lognormal can generate predictions with no censoring", { # nolint: line_length_linter. + skip_on_cran() fit <- readRDS( system.file("extdata/fit.rds", package = "epidist") ) @@ -32,6 +35,7 @@ test_that("posterior_predict_latent_lognormal can generate predictions with no c }) test_that("posterior_predict_latent_lognormal predicts delays for which the data is in the 95% credible interval", { # nolint: line_length_linter. + skip_on_cran() fit <- readRDS( system.file("extdata/fit.rds", package = "epidist") ) @@ -52,6 +56,7 @@ test_that("posterior_predict_latent_lognormal predicts delays for which the data }) test_that("posterior_epred_latent_lognormal creates a array of non-negative numbers with the correct dimensions", { # nolint: line_length_linter. + skip_on_cran() fit <- readRDS( system.file("extdata/fit.rds", package = "epidist") ) @@ -64,6 +69,7 @@ test_that("posterior_epred_latent_lognormal creates a array of non-negative numb }) test_that("log_lik_latent_lognormal produces a vector with length ndraws of finite non-NA numbers", { # nolint: line_length_linter. + skip_on_cran() fit <- readRDS( system.file("extdata/fit.rds", package = "epidist") ) diff --git a/tests/testthat/test-postprocess.R b/tests/testthat/test-postprocess.R index d17d05a99..5ccc13dfc 100644 --- a/tests/testthat/test-postprocess.R +++ b/tests/testthat/test-postprocess.R @@ -9,7 +9,8 @@ test_that("predict_delay_parameters works with NULL newdata and the latent logno output_dir = fs::dir_create(tempfile()) ) pred <- predict_delay_parameters(fit) - expect_s3_class(pred, "data.table") + expect_s3_class(pred, "lognormal_samples") + expect_s3_class(pred, "data.frame") expect_named(pred, c("draw", "index", "mu", "sigma", "mean", "sd")) expect_true(all(pred$mean > 0)) expect_true(all(pred$sd > 0)) @@ -25,10 +26,12 @@ test_that("predict_delay_parameters accepts newdata arguments and prediction by data = prep_obs_sex, formula = brms::bf(mu ~ 1 + sex, sigma ~ 1 + sex), seed = 1, - silent = 2 + silent = 2, + output_dir = fs::dir_create(tempfile()) ) pred_sex <- predict_delay_parameters(fit_sex, prep_obs_sex) - expect_s3_class(pred_sex, "data.table") + expect_s3_class(pred_sex, "lognormal_samples") + expect_s3_class(pred_sex, "data.frame") expect_named(pred_sex, c("draw", "index", "mu", "sigma", "mean", "sd")) expect_true(all(pred_sex$mean > 0)) expect_true(all(pred_sex$sd > 0)) @@ -36,8 +39,9 @@ test_that("predict_delay_parameters accepts newdata arguments and prediction by expect_equal(length(unique(pred_sex$draw)), summary(fit_sex)$total_ndraws) pred_sex_summary <- pred_sex |> + dplyr::mutate(index = as.factor(index)) |> dplyr::left_join( - dplyr::select(data.frame(prep_obs_sex), index = row_id, sex), + dplyr::select(prep_obs_sex, index = row_id, sex), by = "index" ) |> dplyr::group_by(sex) |> @@ -63,13 +67,12 @@ test_that("predict_delay_parameters accepts newdata arguments and prediction by test_that("add_mean_sd.lognormal_samples works with simulated lognormal distribution parameter data", { # nolint: line_length_linter. set.seed(1) - dt <- data.table( + df <- dplyr::tibble( mu = rnorm(n = 100, mean = 1.8, sd = 0.1), sigma = rnorm(n = 100, mean = 0.5, sd = 0.05) ) - class(dt) <- c(class(dt), "lognormal_samples") - x <- add_mean_sd(dt) - expect_s3_class(x, "data.table") + class(df) <- c("lognormal_samples", class(df)) + x <- add_mean_sd(df) expect_named(x, c("mu", "sigma", "mean", "sd")) expect_true(all(x$mean > 0)) expect_true(all(x$sd > 0)) @@ -77,14 +80,13 @@ test_that("add_mean_sd.lognormal_samples works with simulated lognormal distribu test_that("add_mean_sd.gamma_samples works with simulated gamma distribution parameter data", { # nolint: line_length_linter. set.seed(1) - dt <- data.table( + df <- dplyr::tibble( shape = rnorm(n = 100, mean = 2, sd = 0.1), rate = rnorm(n = 100, mean = 3, sd = 0.2) - ) - dt[, mu := shape / rate] - class(dt) <- c(class(dt), "gamma_samples") - x <- add_mean_sd(dt) - expect_s3_class(x, "data.table") + ) |> + dplyr::mutate(mu = shape / rate) + class(df) <- c("gamma_samples", class(df)) + x <- add_mean_sd(df) expect_named(x, c("shape", "rate", "mu", "mean", "sd")) expect_true(all(x$mean > 0)) expect_true(all(x$sd > 0)) diff --git a/vignettes/approx-inference.Rmd b/vignettes/approx-inference.Rmd index 9700e4b8e..d0cd6a102 100644 --- a/vignettes/approx-inference.Rmd +++ b/vignettes/approx-inference.Rmd @@ -110,16 +110,14 @@ sdlog <- 0.5 obs_time <- 25 sample_size <- 200 -obs_cens_trunc <- simulate_gillespie(seed = 101) |> +obs_cens_trunc_samp <- simulate_gillespie(seed = 101) |> simulate_secondary( meanlog = meanlog, sdlog = sdlog ) |> observe_process() |> - filter_obs_by_obs_time(obs_time = obs_time) - -obs_cens_trunc_samp <- - obs_cens_trunc[sample(seq_len(.N), sample_size, replace = FALSE)] + filter_obs_by_obs_time(obs_time = obs_time) |> + slice_sample(n = sample_size, replace = FALSE) ``` We now prepare the data for fitting with the latent individual model, and perform inference with HMC: diff --git a/vignettes/ebola.Rmd b/vignettes/ebola.Rmd index cde97afd4..917bfe6b6 100644 --- a/vignettes/ebola.Rmd +++ b/vignettes/ebola.Rmd @@ -45,7 +45,6 @@ set.seed(123) library(epidist) library(brms) -library(data.table) library(dplyr) library(purrr) library(ggplot2) @@ -151,8 +150,7 @@ That is, $\mu$ and $\sigma$ such that when $x \sim \mathcal{N}(\mu, \sigma)$ the ## Data preparation To prepare the data, we begin by transforming the date columns to `ptime` and `stime` columns for the times of the primary and secondary events respectively. -Both of these columns are relative to the first date of symptom onset in the data. -We also transform the `data.frame` to a `data.table`: +Both of these columns are relative to the first date of symptom onset in the data: ```{r} sierra_leone_ebola_data <- sierra_leone_ebola_data |> @@ -165,8 +163,7 @@ sierra_leone_ebola_data <- sierra_leone_ebola_data |> ptime = as.numeric(date_of_symptom_onset - min(date_of_symptom_onset)), stime = as.numeric(date_of_sample_tested - min(date_of_symptom_onset)) ) |> - select(case, ptime, stime, age, sex, district) |> - as.data.table() + select(case, ptime, stime, age, sex, district) head(sierra_leone_ebola_data) ``` @@ -182,7 +179,7 @@ For the time being, we filter the data to only complete cases (i.e. rows of the ```{r} n <- nrow(obs_cens) -obs_cens <- obs_cens[complete.cases(obs_cens)] +obs_cens <- obs_cens[complete.cases(obs_cens), ] n_complete <- nrow(obs_cens) ``` @@ -196,8 +193,8 @@ Additionally, to speed up computation, we take a random `r 100 * subsample`% sub (In a real analysis, we'd recommend using all the available data). ```{r} -obs_cens <- - obs_cens[sample(seq_len(.N), n_complete * subsample, replace = FALSE)] +obs_cens <- obs_cens |> + slice_sample(n = round(n_complete * subsample), replace = FALSE) ``` ## Model fitting diff --git a/vignettes/epidist.Rmd b/vignettes/epidist.Rmd index 824d1c74f..837892e08 100644 --- a/vignettes/epidist.Rmd +++ b/vignettes/epidist.Rmd @@ -53,11 +53,10 @@ Then, in Section \@ref(fit), we show how `epidist` can be used to accurately est If you would like more technical details, the `epidist` package implements models following best practices as described in @park2024estimating and @charniga2024best. -To run this vignette yourself, as well as the `epidist` package, you will need the `data.table`^[Note that to work with outputs from `epidist` you do not need to use `data.table`: any tool of your preference is suitable!], `purrr`, `ggplot2`, `gt`, and `dplyr` packages installed. +To run this vignette yourself, along with the `epidist` package, you will need the following packages: ```{r load-requirements} library(epidist) -library(data.table) library(purrr) library(ggplot2) library(gt) @@ -66,7 +65,7 @@ library(dplyr) # Example data {#data} -Data should be formatted as a [`data.table`](https://cran.r-project.org/web/packages/data.table/index.html) with the following columns for use within `epidist`: +Data should be formatted as a `data.frame` with the following columns for use within `epidist`: * `case`: The unique case ID. * `ptime`: The time of the primary event. @@ -84,20 +83,21 @@ outbreak <- simulate_gillespie(seed = 101) (ref:outbreak) Early on in the epidemic, there is a high rate of growth in new cases. As more people are infected, the rate of growth slows. (Only every 50th case is shown to avoid over-plotting.) ```{r outbreak, fig.cap="(ref:outbreak)"} -outbreak[case %% 50 == 0, ] |> +outbreak |> + filter(case %% 50 == 0) |> ggplot(aes(x = ptime, y = case)) + geom_point(col = "#56B4E9") + labs(x = "Primary event time (day)", y = "Case number") + theme_minimal() ``` -`outbreak` is a [`data.table`](https://cran.r-project.org/web/packages/data.table/index.html) with the columns `case` and `ptime`. +`outbreak` is a `data.frame` with the columns `case` and `ptime`. Now, to generate secondary events, we will use a lognormal distribution (Figure \@ref(fig:lognormal)) for the delay between primary and secondary events: ```{r} -secondary_dist <- data.table(mu = 1.8, sigma = 0.5) -class(secondary_dist) <- c(class(secondary_dist), "lognormal_samples") +secondary_dist <- data.frame(mu = 1.8, sigma = 0.5) +class(secondary_dist) <- c("lognormal_samples", class(secondary_dist)) secondary_dist <- add_mean_sd(secondary_dist) ``` @@ -130,7 +130,8 @@ obs <- outbreak |> (ref:delay) Secondary events (in green) occur with a delay drawn from the lognormal distribution (Figure \@ref(fig:lognormal)). As with Figure \@ref(fig:outbreak), to make this figure easier to read, only every 50th case is shown. ```{r delay, fig.cap="(ref:delay)"} -obs[case %% 50 == 0, ] |> +obs |> + filter(case %% 50 == 0) |> ggplot(aes(y = case)) + geom_segment( aes(x = ptime, xend = stime, y = case, yend = case), col = "grey" @@ -141,7 +142,7 @@ obs[case %% 50 == 0, ] |> theme_minimal() ``` -`obs` is now a [`data.table`](https://cran.r-project.org/web/packages/data.table/index.html) object with further columns for `delay` and `stime`. +`obs` is now a `data.frame` with further columns for `delay` and `stime`. The secondary event time is simply the primary event time plus the delay: ```{r} @@ -192,8 +193,8 @@ sample_size <- 200 This sample size corresponds to `r 100 * round(sample_size / nrow(obs_cens_trunc), 3)`% of the data. ```{r} -obs_cens_trunc_samp <- - obs_cens_trunc[sample(seq_len(.N), sample_size, replace = FALSE)] +obs_cens_trunc_samp <- obs_cens_trunc |> + slice_sample(n = sample_size, replace = FALSE) ``` Another issue, which `epidist` currently does not account for, is that sometimes only the secondary event might be observed, and not the primary event. @@ -297,9 +298,9 @@ ggplot() + aes(x = x), fun = dlnorm, args = list( - mu = secondary_dist[["mu"]], - sigma = secondary_dist[["sigma"]] - ), + meanlog = secondary_dist[["mu"]], + sdlog = secondary_dist[["sigma"]] + ) ) + geom_function( data = data.frame(x = c(0, 30)), diff --git a/vignettes/faq.Rmd b/vignettes/faq.Rmd index 3ca8185fb..097d7e78f 100644 --- a/vignettes/faq.Rmd +++ b/vignettes/faq.Rmd @@ -37,16 +37,14 @@ sdlog <- 0.5 obs_time <- 25 sample_size <- 200 -obs_cens_trunc <- simulate_gillespie(seed = 101) |> +obs_cens_trunc_samp <- simulate_gillespie(seed = 101) |> simulate_secondary( meanlog = meanlog, sdlog = sdlog ) |> observe_process() |> - filter_obs_by_obs_time(obs_time = obs_time) - -obs_cens_trunc_samp <- - obs_cens_trunc[sample(seq_len(.N), sample_size, replace = FALSE)] + filter_obs_by_obs_time(obs_time = obs_time) |> + slice_sample(n = sample_size, replace = FALSE) data <- as_latent_individual(obs_cens_trunc_samp) fit <- epidist(