Skip to content

Commit

Permalink
implement dist_spec interface (#363)
Browse files Browse the repository at this point in the history
* implement `dist_spec` interface

* Add skipping of stan tests in the expected places

* fix checks

* fix epinow example

* update syntax in more places

* update another example

* move parenthesis to the right place

* update uncertainty in estimate_infections example

* fix use of generation_time_opts as resource to call get_generation_time()

* break line to make R CMD CHECK happy

* fix uses of `trunc_opts`

* fix updating of `cur_len` in ragged convolution

* remove bounds on mean parameters

* use dist_spec syntax in `estimate_delays`

* add print function for `dist_spec`

* add names to printing if given

* clarify printouts

* reduce unnecessary function calls

* Revert "reduce unnecessary function calls"

This reverts commit 9037e21.

* fix typo

* add default option for generation time

* simplify delay inits

* update pmf doc

* dist -> distribution

* fix max of np dist logic

* simplify pmf truncation syntax

* fix typos and use `is`

* extract function for stan code conversion

* fix variable name

* fix function name

* do truncnorm with appropriate lengths

* fix initial condition sampling

* update `to_stan` documentation

* fix typo

* stan model with unified delays

* update R access to unified dist interface

* update tests

* ensure arrays are arrays

* simplify stan seq (and avoid conflict with R)

* fix test

* fix simulation models

* fix final tests

* update usage of c -> +

* Automatic readme update

* update examples/doc and re-doc

* linting

* update docs

* final requested lint

* update return type of bootstrapped_dist_fit

* redoc

* update estimate_delay to reflect changes

* dot product for all convolutions

* report gt mean and var

* bug fix in calculation of max delays

* Automatic readme update

* update tests

* clean whitespace

* reduce number of calculations by precomputing len

* optional head/tail

* Revert "optional head/tail"

This reverts commit 8c59db1.

* don't convolve first pmf

* reduce vector copying

* fix reversing

* fix printing of combined distributions

* add examples, export, and add basic dist plotting

* Automatic readme update

* add some tests for dist_spec

* add tests for +.dist_spec

* add tests for mean.dist_spec

* add some basic additional tests and docs

* linting

* fix linting

* export c

* fix plotting to work with c() method for dist_spec

* more linting fixes

* remove extract line in generation_time.stan

* add a check in convolve_rev_pmf when len >= xlen + ylen and update tests

* be more efficient when calc discrete pmfs

* catch missing indexes in pmf calc

* clarify comments

* add tolerance for +.dist_spec

* don't load testthat

* trigger benchmarking

* remove benchmark trigger

* linting

* Automatic readme update

* trigger benchmarking

* remove benchmark trigger

* refine tolerance checks for convolution

* fix example

* add an internal function

* trigger benchmarking

* benchmarking off

* add back in missing tolerance docs

* fix edge case check for length 1 pmfs

* whitespace linting

* test more carefully

* use commas like a smart boy

* crank that adapt delta handle

* Update R/create.R

Co-authored-by: Sebastian Funk <[email protected]>

* Update R/get.R

Co-authored-by: Sebastian Funk <[email protected]>

* Update R/opts.R

Co-authored-by: Sebastian Funk <[email protected]>

* Update R/dist.R

Co-authored-by: Sebastian Funk <[email protected]>

* fixed @internal and brackets + fcase

* don't export c.dist_spec

* drop c() example from plot

---------

Co-authored-by: Sam Abbott <[email protected]>
Co-authored-by: GitHub Action <[email protected]>
  • Loading branch information
3 people authored Jun 14, 2023
1 parent 1b25f69 commit d0bea23
Show file tree
Hide file tree
Showing 91 changed files with 2,042 additions and 1,137 deletions.
1 change: 1 addition & 0 deletions .Rbuildignore
Original file line number Diff line number Diff line change
Expand Up @@ -24,3 +24,4 @@
^\.devcontainer$
^CRAN-SUBMISSION$
^touchstone$
^\.benchmark$
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,4 @@ inst/include/*.o
src

.DS_Store
.vscode
8 changes: 8 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,9 +1,13 @@
# Generated by roxygen2: do not edit by hand

S3method("+",dist_spec)
S3method(mean,dist_spec)
S3method(plot,dist_spec)
S3method(plot,epinow)
S3method(plot,estimate_infections)
S3method(plot,estimate_secondary)
S3method(plot,estimate_truncation)
S3method(print,dist_spec)
S3method(summary,epinow)
S3method(summary,estimate_infections)
export(R_to_growth)
Expand Down Expand Up @@ -145,6 +149,7 @@ importFrom(ggplot2,geom_line)
importFrom(ggplot2,geom_linerange)
importFrom(ggplot2,geom_point)
importFrom(ggplot2,geom_ribbon)
importFrom(ggplot2,geom_step)
importFrom(ggplot2,geom_vline)
importFrom(ggplot2,ggplot)
importFrom(ggplot2,ggplot_build)
Expand All @@ -157,6 +162,7 @@ importFrom(ggplot2,scale_x_date)
importFrom(ggplot2,scale_y_continuous)
importFrom(ggplot2,theme)
importFrom(ggplot2,theme_bw)
importFrom(ggplot2,vars)
importFrom(lifecycle,deprecate_soft)
importFrom(lifecycle,deprecate_warn)
importFrom(lubridate,days)
Expand Down Expand Up @@ -188,6 +194,7 @@ importFrom(rstan,summary)
importFrom(rstan,vb)
importFrom(runner,mean_run)
importFrom(scales,comma)
importFrom(stats,convolve)
importFrom(stats,glm)
importFrom(stats,lm)
importFrom(stats,median)
Expand All @@ -207,5 +214,6 @@ importFrom(stats,sd)
importFrom(stats,var)
importFrom(truncnorm,rtruncnorm)
importFrom(utils,capture.output)
importFrom(utils,head)
importFrom(utils,tail)
useDynLib(EpiNow2, .registration=TRUE)
22 changes: 12 additions & 10 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,24 @@

This release is in development. For a stable release install 1.3.5 from CRAN.

## Breaking changes

- The external distribution interface has been updated to use the `dist_spec()` function. This comes with a range of benefits, including optimising model fitting when static delays are used (by convolving when first defined vs in stan), easy printing (using `print()`), and easy plotting (using `plot()`). It also makes it possible to use all supported distributions everywhere (i.e., as a generation time or reporting delay). However, this update will break most users' code as the interface has changed. See the documentation for `dist_spec()` for more details. By @sbfnk in #363 and reviewed by @seabbs.

## Package

* Model description has been expanded to include more detail.
* Moved to a GitHub Action to only lint changed files.
* Linted the package with a wider range of default linters.
* Added a GitHub Action to build the README when it is altered.
* Added handling of edge case where we sample from the negative binomial with
mean close or equal to 0. By @sbfnk in #366.
* Replaced use of nested `ifelse()` and `data.table::fifelse()` in the
code base with `data.table::fcase()`. By @jamesmbaazam in #383 and reviewed by @seabbs.
* Reviewed the example in `calc_backcalc_data()` to call `calc_backcalc_data()`
instead of `create_gp_data()`. By @jamesmbaazam in #388 and reviewed by @seabbs.
* Model description has been expanded to include more detail. By @sbfnk in #373 and reviewed by @seabbs.
* Moved to a GitHub Action to only lint changed files. By @seabbs in #378.
* Linted the package with a wider range of default linters. By @seabbs in #378.
* Added a GitHub Action to build the README when it is altered. By @seabbs.
* Added handling of edge case where we sample from the negative binomial with mean close or equal to 0. By @sbfnk in #366 and reviewed by @seabbs.
* Replaced use of nested `ifelse()` and `data.table::fifelse()` in the code base with `data.table::fcase()`. By @jamesmbaazam in #383 and reviewed by @seabbs.
* Reviewed the example in `calc_backcalc_data()` to call `calc_backcalc_data()` instead of `create_gp_data()`. By @jamesmbaazam in #388 and reviewed by @seabbs.
* Improved compilation times by reducing the number of distinct stan models and deprecated `tune_inv_gamma()`. By @sbfnk in #394 and reviewed by @seabbs.
* Changed touchstone settings so that benchmarks are only performed if the stan model is changed. By @sbfnk in #400 and reviewed by @seabbs.
* [pak](https://pak.r-lib.org/) is now suggested for installing the developmental version of the package. By @jamesmbaazam in #407 and reviewed by @seabbs. This has been successfully tested on MacOS Ventura, Ubuntu 20.04, and Windows 10. Users are advised to use `remotes::install_github("epiforecasts/EpiNow2")` if `pak` fails and if both fail, raise an issue.
* `dist_fit()`'s `samples` argument now takes a default value of 1000 instead of NULL. If a supplied `samples` is less than 1000, it is changed to 1000 and a warning is thrown to indicate the change. By @jamesmbaazam in #389 and reviewed by @seabbs.
* The internal distribution interface has been streamlined to reduce code duplication. By @sbfnk in #363 and reviewed by @seabbs.

# EpiNow2 1.3.5

Expand Down
154 changes: 82 additions & 72 deletions R/create.R
Original file line number Diff line number Diff line change
Expand Up @@ -414,10 +414,8 @@ create_obs_model <- function(obs = obs_opts(), dates) {
#'
#' @param shifted_cases A dataframe of delay shifted cases
#'
#' @param truncation `r lifecycle::badge("experimental")` A list of options as
#' generated by `trunc_opts()` defining the truncation of observed data.
#' Defaults to `trunc_opts()`. See `estimate_truncation()` for an approach to
#' estimating truncation from data.
#' @param seeding_time Integer; seeding time, usually obtained using
#' `get_seeding_time()`
#'
#' @inheritParams create_gp_data
#' @inheritParams create_obs_model
Expand All @@ -430,34 +428,20 @@ create_obs_model <- function(obs = obs_opts(), dates) {
#' @author Sam Abbott
#' @author Sebastian Funk
#' @export
create_stan_data <- function(reported_cases, generation_time,
rt, gp, obs, delays, horizon,
backcalc, shifted_cases,
truncation) {
## make sure we have at least gt_max seeding time
delays$seeding_time <- max(delays$seeding_time, generation_time$max)
create_stan_data <- function(reported_cases, seeding_time,
rt, gp, obs, horizon,
backcalc, shifted_cases) {

## for backwards compatibility call generation_time_opts internally
if (is.list(generation_time) &&
all(c("mean", "mean_sd", "sd", "sd_sd") %in% names(generation_time))) {
generation_time <- do.call(generation_time_opts, generation_time)
}

cases <- reported_cases[(delays$seeding_time + 1):(.N - horizon)]$confirm
cases <- reported_cases[(seeding_time + 1):(.N - horizon)]$confirm

data <- list(
cases = cases,
shifted_cases = shifted_cases,
t = length(reported_cases$date),
horizon = horizon,
burn_in = 0
burn_in = 0,
seeding_time = seeding_time
)
# add gt data
data <- c(data, generation_time)
# add delay data
data <- c(data, delays)
# add truncation data
data <- c(data, truncation)
# add Rt data
data <- c(
data,
Expand All @@ -476,10 +460,6 @@ create_stan_data <- function(reported_cases, generation_time,
is.na(data$prior_infections) || is.null(data$prior_infections),
0, data$prior_infections
)
if (is.null(data$gt_weight)) {
## default: weigh by number of data points
data$gt_weight <- data$t - data$seeding_time - data$horizon
}
if (data$seeding_time > 1) {
safe_lm <- purrr::safely(stats::lm)
data$prior_growth <- safe_lm(log(confirm) ~ t, data = first_week)[[1]]
Expand Down Expand Up @@ -532,37 +512,20 @@ create_stan_data <- function(reported_cases, generation_time,
create_initial_conditions <- function(data) {
init_fun <- function() {
out <- list()
if (data$n_uncertain_mean_delays > 0) {
out$delay_mean <- array(purrr::map2_dbl(
data$delay_mean_mean[data$uncertain_mean_delays],
data$delay_mean_sd[data$uncertain_mean_delays] * 0.1,
~ rnorm(1, mean = .x, sd = .y)
if (data$delay_n_p > 0) {
out$delay_mean <- array(truncnorm::rtruncnorm(
n = data$delay_n_p, a = 0,
mean = data$delay_mean_mean, sd = data$delay_mean_sd * 0.1
))
}
if (data$n_uncertain_sd_delays > 0) {
out$delay_sd <- array(purrr::map2_dbl(
data$delay_sd_mean[data$uncertain_sd_delays],
data$delay_sd_sd[data$uncertain_sd_delays] * 0.1,
~ rnorm(1, mean = .x, sd = .y)
out$delay_sd <- array(truncnorm::rtruncnorm(
n = data$delay_n_p, a = 0,
mean = data$delay_sd_mean, sd = data$delay_sd_sd * 0.1
))
} else {
out$delay_mean <- array(numeric(0))
out$delay_sd <- array(numeric(0))
}
if (data$truncation > 0) {
if (data$trunc_mean_sd > 0) {
out$truncation_mean <- array(rnorm(1,
mean = data$trunc_mean_mean,
sd = data$trunc_mean_sd * 0.1
))
}
if (data$trunc_sd_sd > 0) {
out$truncation_sd <- array(
truncnorm::rtruncnorm(1,
a = 0,
mean = data$trunc_sd_mean,
sd = data$trunc_sd_sd * 0.1
)
)
}
}

if (data$fixed == 0) {
out$eta <- array(rnorm(data$M, mean = 0, sd = 0.1))
out$rho <- array(rlnorm(1,
Expand All @@ -579,6 +542,10 @@ create_initial_conditions <- function(data) {
out$alpha <- array(
truncnorm::rtruncnorm(1, a = 0, mean = 0, sd = data$alpha_sd)
)
} else {
out$eta <- array(numeric(0))
out$rho <- array(numeric(0))
out$alpha <- array(numeric(0))
}
if (data$model_type == 1) {
out$rep_phi <- array(
Expand All @@ -597,30 +564,23 @@ create_initial_conditions <- function(data) {
n = 1, mean = convert_to_logmean(data$r_mean, data$r_sd),
sd = convert_to_logsd(data$r_mean, data$r_sd) * 0.1
))
if (data$gt_mean_sd > 0) {
out$gt_mean <- array(truncnorm::rtruncnorm(1,
a = 0, mean = data$gt_mean_mean,
sd = data$gt_mean_sd * 0.1
))
}
if (data$gt_sd_sd > 0) {
out$gt_sd <- array(truncnorm::rtruncnorm(1,
a = 0, mean = data$gt_sd_mean,
sd = data$gt_sd_sd * 0.1
))
}
}

if (data$bp_n > 0) {
out$bp_sd <- array(truncnorm::rtruncnorm(1, a = 0, mean = 0, sd = 0.1))
out$bp_effects <- array(rnorm(data$bp_n, 0, 0.1))
}
if (data$bp_n > 0) {
out$bp_sd <- array(truncnorm::rtruncnorm(1, a = 0, mean = 0, sd = 0.1))
out$bp_effects <- array(rnorm(data$bp_n, 0, 0.1))
} else {
out$bp_sd <- array(numeric(0))
out$bp_effects <- array(numeric(0))
}
if (data$obs_scale == 1) {
out$frac_obs <- array(truncnorm::rtruncnorm(1,
a = 0, b = 1,
mean = data$obs_scale_mean,
sd = data$obs_scale_sd * 0.1
))
} else {
out$frac_obs <- array(numeric(0))
}
if (data$week_effect > 0) {
out$day_of_week_simplex <- array(
Expand Down Expand Up @@ -675,3 +635,53 @@ create_stan_args <- function(stan = stan_opts(),
args$return_fit <- NULL
return(args)
}

##' Build the delay-related variables passed to the stan model
##'
##' Merges the supplied delay distributions with `c()` and derives the
##' bookkeeping variables (type counts, identifiers and group boundaries)
##' from the merged object.
##'
##' @param ... Named delay distributions specified using `dist_spec()`.
##' Each argument name labels the matching `<name>_id` entry of the result.
##' @param ot Integer, number of observations. Used as the prior weight of
##' every delay whose weight is 0.
##' @return A named list of variables as expected by the stan model. Entries
##' derived from the delays carry a `delay_` prefix; one `<name>_id` entry
##' per argument holds the ID of that delay type, or 0 if its `n` is 0.
##' @importFrom purrr transpose map
##' @author Sebastian Funk
create_stan_delays <- function(..., ot) {
  delay_specs <- list(...)
  ## merge all delays into a single plain list
  merged <- unclass(c(...))
  ## number of delays contributed by each argument
  n_by_type <- unlist(purrr::transpose(delay_specs)$n)
  non_empty <- n_by_type > 0
  n_non_empty <- sum(non_empty)
  ## non-empty types are numbered consecutively, empty ones keep ID 0
  ids <- rep(0L, length(n_by_type))
  ids[non_empty] <- seq_len(n_non_empty)
  names(ids) <- paste(names(n_by_type), "id", sep = "_")

  ## start constructing the stan object from the merged delays, adding the
  ## number of non-empty types and a flag that is 1 wherever `fixed` is 0
  ret <- c(merged, list(
    types = n_non_empty,
    types_p = array(1L - merged$fixed)
  ))
  ## delay identifiers: entries flagged 1 are numbered 1..n_p, entries
  ## flagged 0 are numbered 1..n_np
  types_id <- integer(0)
  types_id[ret$types_p == 1] <- seq_len(ret$n_p)
  types_id[ret$types_p == 0] <- seq_len(ret$n_np)
  ret$types_id <- array(types_id)
  ## 1-based cumulative boundaries of the non-empty types
  ret$types_groups <- array(c(0, cumsum(unname(n_by_type[non_empty]))) + 1)
  ## 1-based cumulative boundaries of the pmfs
  ret$np_pmf_groups <- array(c(0, cumsum(merged$np_pmf_length)) + 1)
  ## a weight of 0 means: weigh by the number of observations
  zero_weight <- ret$weight == 0
  if (any(zero_weight)) {
    ret$weight[zero_weight] <- ot
  }
  ## remove auxiliary variables
  ret$fixed <- NULL
  ret$np_pmf_length <- NULL

  ## prefix the delay variables, then append the (unprefixed) type IDs
  names(ret) <- paste("delay", names(ret), sep = "_")
  c(ret, ids)
}
Loading

0 comments on commit d0bea23

Please sign in to comment.