Skip to content

Commit

Permalink
Move package depends to suggest (#798)
Browse files Browse the repository at this point in the history
* Make progressr a suggested package

* Make `future.apply` and `future` suggested

* add news item

* qualify package name

* progressr for batch simulation

* fix package name

* use do.call for future_lapply call

* fix argument name

* add reviewer

Co-authored-by: Sam Abbott <[email protected]>

* update variable name(s)

---------

Co-authored-by: Sam Abbott <[email protected]>
  • Loading branch information
sbfnk and seabbs authored Sep 30, 2024
1 parent de1e374 commit 2fb6f1a
Show file tree
Hide file tree
Showing 15 changed files with 116 additions and 65 deletions.
6 changes: 3 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -103,15 +103,12 @@ Imports:
cli,
data.table,
futile.logger (>= 1.4),
future,
future.apply,
ggplot2,
lifecycle,
lubridate,
methods,
patchwork,
posterior,
progressr,
purrr,
R.utils (>= 2.0.0),
Rcpp (>= 0.12.0),
Expand All @@ -126,9 +123,12 @@ Imports:
Suggests:
cmdstanr,
covr,
future,
future.apply,
here,
knitr,
precommit,
progressr,
rmarkdown,
spelling,
testthat,
Expand Down
6 changes: 0 additions & 6 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,6 @@ importFrom(futile.logger,flog.threshold)
importFrom(futile.logger,flog.trace)
importFrom(futile.logger,flog.warn)
importFrom(futile.logger,ftry)
importFrom(future,availableCores)
importFrom(future,plan)
importFrom(future,tweak)
importFrom(future.apply,future_lapply)
importFrom(ggplot2,.data)
importFrom(ggplot2,aes)
importFrom(ggplot2,coord_cartesian)
Expand Down Expand Up @@ -205,8 +201,6 @@ importFrom(lubridate,days)
importFrom(lubridate,wday)
importFrom(patchwork,plot_layout)
importFrom(posterior,mcse_mean)
importFrom(progressr,progressor)
importFrom(progressr,with_progress)
importFrom(purrr,compact)
importFrom(purrr,flatten)
importFrom(purrr,keep)
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ A release that introduces model improvements to the Gaussian Process models, alo
- All functions now use the `{cli}` R package to signal errors, warnings, and messages. By @jamesmbaazam in #762 and reviewed by @seabbs.
- `fix_dist()` has been renamed to `fix_parameters()` because it removes the uncertainty in a distribution's parameters. By @sbfnk in #733 and reviewed by @jamesmbaazam.
- `plot.dist_spec` now uses color instead of line types to display pmfs vs cmfs. By @jamesmbaazam in #788 and reviewed by @sbfnk.
- The use of the `{progressr}` package for displaying progress bars is now optional, as is the use of `{future}` and `{future.apply}` for parallelisation. By @sbfnk in #798 and reviewed by @seabbs.

## Bug fixes

Expand Down
1 change: 0 additions & 1 deletion R/deprecated.R
Original file line number Diff line number Diff line change
Expand Up @@ -221,7 +221,6 @@ lognorm_dist_def <- function(mean, mean_sd,
#' @inheritParams estimate_infections
#' @inheritParams adjust_infection_to_report
#' @importFrom data.table data.table rbindlist
#' @importFrom future.apply future_lapply
report_cases <- function(case_estimates,
case_forecast = NULL,
delays,
Expand Down
18 changes: 10 additions & 8 deletions R/estimate_delay.R
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,6 @@ dist_fit <- function(values = NULL, samples = 1000, cores = 1,
#'
#' @return A `<dist_spec>` object summarising the bootstrapped distribution
#' @importFrom purrr list_transpose
#' @importFrom future.apply future_lapply
#' @importFrom rstan extract
#' @importFrom data.table data.table rbindlist
#' @importFrom cli cli_abort col_blue
Expand Down Expand Up @@ -199,7 +198,7 @@ bootstrapped_dist_fit <- function(values, dist = "lognormal",
dist_samples <- get_single_dist(values, samples = samples)
} else {
## Fit each sub sample
dist_samples <- future.apply::future_lapply(1:bootstraps,
dist_samples <- lapply_func(1:bootstraps,
function(boot) {
get_single_dist(
sample(values,
Expand All @@ -209,12 +208,15 @@ bootstrapped_dist_fit <- function(values, dist = "lognormal",
samples = ceiling(samples / bootstraps)
)
},
future.scheduling = Inf,
future.globals = c(
"values", "bootstraps", "samples",
"bootstrap_samples", "get_single_dist"
),
future.packages = "data.table", future.seed = TRUE
future.opts = list(
future.scheduling = Inf,
future.globals = c(
"values", "bootstraps", "samples",
"bootstrap_samples", "get_single_dist"
),
future.packages = "data.table",
future.seed = TRUE
)
)


Expand Down
6 changes: 2 additions & 4 deletions R/fit.R
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,6 @@
#'
#' @importFrom futile.logger flog.debug flog.info flog.error
#' @importFrom R.utils withTimeout
#' @importFrom future.apply future_lapply
#' @importFrom purrr compact
#' @importFrom rstan sflist2stanfit sampling
#' @importFrom rlang abort cnd_muffle
Expand Down Expand Up @@ -103,12 +102,11 @@ fit_model_with_nuts <- function(args, future = FALSE, max_execution_time = Inf,
chains <- args$chains
args$chains <- 1
args$cores <- 1
fits <- future.apply::future_lapply(1:chains,
fits <- lapply_func(1:chains,
fit_chain,
stan_args = args,
max_time = max_execution_time,
catch = TRUE,
future.seed = TRUE
catch = TRUE
)
if (stuck_chains > 0) {
fits[1:stuck_chains] <- NULL
Expand Down
30 changes: 18 additions & 12 deletions R/regional_epinow.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
#'
#' Regions can be estimated in parallel using the `{future}` package (see
#' [setup_future()]). The progress of producing estimates across multiple
#' regions is tracked using the `{progressr}` package. Modify this behaviour
#' regions can be tracked using the `{progressr}` package, if it is
#' installed. Modify this behaviour using [progressr::handlers()] and enable
#' it in batch by setting `R_PROGRESSR_ENABLE=TRUE` as an environment variable.
#'
Expand Down Expand Up @@ -54,13 +54,11 @@
#' @export
#' @seealso [epinow()] [estimate_infections()] [setup_future()]
#' [regional_summary()]
#' @importFrom future.apply future_lapply
#' @importFrom data.table as.data.table setDT copy setorder
#' @importFrom purrr safely map compact keep
#' @importFrom futile.logger flog.info flog.warn flog.trace
#' @importFrom R.utils withTimeout
#' @importFrom rlang cnd_muffle
#' @importFrom progressr with_progress progressor
#' @examples
#' \donttest{
#' # set number of cores to use
Expand Down Expand Up @@ -161,9 +159,8 @@ regional_epinow <- function(data,
" function"
)

progressr::with_progress({
progress_fn <- progressr::progressor(along = regions)
regional_out <- future.apply::future_lapply(regions, run_region,
run_regions <- function(progress_fn = NULL) {
lapply_func(regions, run_region,
generation_time = generation_time,
delays = delays,
truncation = truncation,
Expand All @@ -186,10 +183,19 @@ regional_epinow <- function(data,
progress_fn = progress_fn,
verbose = verbose,
...,
future.scheduling = Inf,
future.seed = TRUE
future.opts = list(
future.scheduling = Inf,
future.seed = TRUE
)
)
})
}
if (requireNamespace("progressr", quietly = TRUE)) {
progressr::with_progress({
regional_out <- run_regions(progressr::progressor(along = regions))
})
} else {
regional_out <- run_regions()
}

out <- process_regions(regional_out, regions)
regional_out <- out$all
Expand Down Expand Up @@ -313,7 +319,7 @@ clean_regions <- function(data, non_zero_points) {
#'
#' @param target_region Character string indicating the region being evaluated
#' @param progress_fn Function as returned by [progressr::progressor()]. Allows
#' the use of a progress bar.
#' the use of a progress bar. If NULL (default), no progress bar is used.
#'
#' @param complete_logger Character string indicating the logger to output
#' the completion of estimation to.
Expand Down Expand Up @@ -341,7 +347,7 @@ run_region <- function(target_region,
output,
complete_logger,
verbose,
progress_fn,
progress_fn = NULL,
...) {
futile.logger::flog.info("Initialising estimates for: %s", target_region,
name = "EpiNow2.epinow"
Expand Down Expand Up @@ -390,7 +396,7 @@ run_region <- function(target_region,
complete_logger
)

if (!missing(progress_fn)) {
if (!is.null(progress_fn)) {
progress_fn(sprintf("Region: %s", target_region))
}
return(out)
Expand Down
14 changes: 12 additions & 2 deletions R/setup.R
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,8 @@ setup_default_logging <- function(logs = tempdir(check = TRUE),
#' A utility function that aims to streamline the set up
#' of the required future backend with sensible defaults for most users of
#' [regional_epinow()]. More advanced users are recommended to setup their own
#' `{future}` backend based on their available resources.
#' `{future}` backend based on their available resources. Running this requires
#' the `{future}` package to be installed.
#'
#' @param strategies A vector length 1 to 2 of strategies to pass to
#' [future::plan()]. Nesting of parallelisation is from the top level down.
Expand All @@ -136,7 +137,6 @@ setup_default_logging <- function(logs = tempdir(check = TRUE),
#'
#' @inheritParams regional_epinow
#' @importFrom futile.logger flog.error flog.info flog.debug
#' @importFrom future availableCores plan tweak
#' @importFrom cli cli_abort
#' @export
#' @return Numeric number of cores to use per worker. If greater than 1 pass to
Expand All @@ -145,6 +145,16 @@ setup_default_logging <- function(logs = tempdir(check = TRUE),
setup_future <- function(data,
strategies = c("multisession", "multisession"),
min_cores_per_worker = 4) {
if (!requireNamespace("future", quietly = TRUE)) {
futile.logger::flog.error(
"The future package is required for parallelisation"
)
cli_abort(
c(
"!" = "The future package is required for parallelisation."
)
)
}
if (length(strategies) > 2 || length(strategies) == 0) {
futile.logger::flog.error("1 or 2 strategies should be used")
cli_abort(
Expand Down
48 changes: 25 additions & 23 deletions R/simulate_infections.R
Original file line number Diff line number Diff line change
Expand Up @@ -249,13 +249,11 @@ simulate_infections <- function(estimates, R, initial_infections,
#' simulate. May decrease run times due to reduced IO costs but this is still
#' being evaluated. If set to NULL then all simulations are done at once.
#'
#' @param verbose Logical defaults to [interactive()]. Should a progress bar
#' (from `progressr`) be shown.
#' @param verbose Logical, defaults to [interactive()]. If `TRUE` and the
#' `{progressr}` package is installed, a progress bar is shown.
#' @inheritParams stan_opts
#' @importFrom rstan extract sampling
#' @importFrom purrr list_transpose map safely compact
#' @importFrom future.apply future_lapply
#' @importFrom progressr with_progress progressor
#' @importFrom data.table rbindlist as.data.table
#' @importFrom lubridate days
#' @importFrom checkmate assert_class assert_names test_numeric test_data_frame
Expand Down Expand Up @@ -472,39 +470,43 @@ forecast_infections <- function(estimates,

safe_batch <- safely(batch_simulate)

if (backend == "cmdstanr") {
lapply_func <- lapply ## future_lapply can't handle cmdstanr
} else {
lapply_func <- function(...) future_lapply(future.seed = TRUE, ...)
}

## simulate in batches
with_progress({
if (verbose) {
p <- progressor(along = batches)
}
out <- lapply_func(batches,
process_batches <- function(p = NULL) {
lapply_func(batches,
function(batch) {
if (verbose) {
if (!is.null(p)) {
p()
}
safe_batch(
estimates, draws, model,
shift, dates, batch[[1]],
batch[[2]]
)[[1]]
}
},
future.opts = list(
future.seed = TRUE
),
backend = backend
)
})
}

## simulate in batches
if (verbose && requireNamespace("progressr", quietly = TRUE)) {
p <- progressr::progressor(along = batches)
progressr::with_progress({
regional_out <- process_batches(p)
})
} else {
regional_out <- process_batches()
}

## join batches
out <- compact(out)
out <- list_transpose(out, simplify = FALSE)
out <- map(out, rbindlist)
regional_out <- compact(regional_out)
regional_out <- list_transpose(regional_out, simplify = FALSE)
regional_out <- map(regional_out, rbindlist)

## format output
format_out <- format_fit(
posterior_samples = out,
posterior_samples = regional_out,
horizon = estimates$args$horizon,
shift = shift,
burn_in = 0,
Expand Down
15 changes: 15 additions & 0 deletions R/utilities.R
Original file line number Diff line number Diff line change
Expand Up @@ -402,6 +402,21 @@ set_dt_single_thread <- function() {
)
}

#' Apply a function over a list, in parallel where possible
#'
#' Internal wrapper that calls [future.apply::future_lapply()] when the
#' `{future.apply}` package is installed and `backend` is `"rstan"`, and
#' falls back to [lapply()] otherwise.
#'
#' @param ... Arguments passed on to the apply function that is used
#' (typically `X`, `FUN` and any further arguments to `FUN`).
#' @param future.opts A named list of additional arguments (e.g.
#' `future.seed`, `future.scheduling`) that are appended to `...` when
#' [future.apply::future_lapply()] is used. Ignored when falling back to
#' [lapply()].
#' @return The result of applying the function, as returned by
#' [future.apply::future_lapply()] or [lapply()].
#' @keywords internal
#' @inheritParams stan_opts
lapply_func <- function(..., backend = "rstan", future.opts = list()) {
  # Check the backend first so that future.apply is only loaded when it is
  # actually going to be used.
  use_future <- backend == "rstan" &&
    requireNamespace("future.apply", quietly = TRUE)
  if (!use_future) {
    return(lapply(...))
  }

  parallel_lapply <- future.apply::future_lapply
  # future_lapply() defaults `future.envir` to parent.frame(); when called
  # via do.call() from here that would be this wrapper's frame rather than
  # the frame of the function that asked for the apply. Pass the caller's
  # frame explicitly unless one was supplied, and only if the installed
  # version of future.apply has this argument.
  if (is.null(future.opts[["future.envir"]]) &&
    "future.envir" %in% names(formals(parallel_lapply))) {
    future.opts[["future.envir"]] <- parent.frame()
  }
  do.call(parallel_lapply, c(list(...), future.opts))
}

#' @importFrom stats glm median na.omit pexp pgamma plnorm quasipoisson rexp
#' @importFrom stats rlnorm rnorm rpois runif sd var rgamma pnorm
globalVariables(
Expand Down
4 changes: 2 additions & 2 deletions man/forecast_infections.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

23 changes: 23 additions & 0 deletions man/lapply_func.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/regional_epinow.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 2fb6f1a

Please sign in to comment.