Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change all priors to use <dist_spec> #871

Merged
merged 15 commits into from
Dec 6, 2024
Merged
1 change: 1 addition & 0 deletions .github/workflows/lint-only-changed-files.yaml
Original file line number Diff line number Diff line change
@@ -29,6 +29,7 @@ jobs:
any::gh
any::lintr
any::purrr
progressr

- name: Add lintr options
run: |
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Generated by roxygen2: do not edit by hand

S3method("!=",dist_spec)
S3method("+",dist_spec)
S3method("==",dist_spec)
S3method(c,dist_spec)
S3method(collapse,dist_spec)
S3method(collapse,multi_dist_spec)
5 changes: 3 additions & 2 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -15,8 +15,9 @@
estimate_infections()
```

- A bug was fixed where the initial growth was never estimated (i.e. the prior mean was always zero). By @sbfnk in #853 and reviewed by @seabbs.
- A bug was fixed where an internal function for applying a default cdf cutoff failed due to a difference a vector length issue. By @jamesmbaazam in #858 and reviewed by @sbfnk.
- A bug was fixed where the initial growth was never estimated (i.e. the prior mean was always zero). By @sbfnk in #853 and reviewed by @seabbs.
- A bug was fixed where an internal function for applying a default cdf cutoff failed due to a difference a vector length issue. By @jamesmbaazam in #858 and reviewed by @sbfnk.
- All parameters have been changed to the new parameter interface. By @sbfnk in # and reviewed by @.
sbfnk marked this conversation as resolved.
Show resolved Hide resolved

## Documentation

166 changes: 121 additions & 45 deletions R/create.R
Original file line number Diff line number Diff line change
@@ -319,8 +319,6 @@ create_rt_data <- function(rt = rt_opts(), breakpoints = NULL,

# map settings to underlying gp stan requirements
rt_data <- list(
r_mean = rt$prior$mean,
r_sd = rt$prior$sd,
estimate_r = as.numeric(rt$use_rt),
bp_n = ifelse(rt$use_breakpoints, max(breakpoints) - 1, 0),
breakpoints = breakpoints,
@@ -433,8 +431,6 @@ create_gp_data <- function(gp = gp_opts(), data) {
ls_sdlog = convert_to_logsd(gp$ls_mean, gp$ls_sd),
ls_min = gp$ls_min,
ls_max = gp$ls_max,
alpha_mean = gp$alpha_mean,
alpha_sd = gp$alpha_sd,
gp_type = data.table::fcase(
gp$kernel == "se", 0,
gp$kernel == "periodic", 1,
@@ -472,7 +468,7 @@ create_gp_data <- function(gp = gp_opts(), data) {
#'
#' # Applying a observation scaling to the data
#' create_obs_model(
#' obs_opts(scale = list(mean = 0.4, sd = 0.01)), dates = dates
#' obs_opts(scale = Normal(mean = 0.4, sd = 0.01)), dates = dates
#' )
#'
#' # Apply a custom week week length
@@ -481,13 +477,9 @@ create_gp_data <- function(gp = gp_opts(), data) {
create_obs_model <- function(obs = obs_opts(), dates) {
data <- list(
model_type = as.numeric(obs$family == "negbin"),
phi_mean = obs$phi$mean,
phi_sd = obs$phi$sd,
week_effect = ifelse(obs$week_effect, obs$week_length, 1),
obs_weight = obs$weight,
obs_scale = as.integer(obs$scale$sd > 0 || obs$scale$mean != 1),
obs_scale_mean = obs$scale$mean,
obs_scale_sd = obs$scale$sd,
obs_scale = as.integer(obs$scale != Fixed(1)),
likelihood = as.numeric(obs$likelihood),
return_likelihood = as.numeric(obs$return_likelihood)
)
@@ -589,15 +581,30 @@ create_stan_data <- function(data, seeding_time,
)
)

# parameters
stan_data <- c(
stan_data,
create_stan_params(
alpha = gp$alpha,
R0 = rt$prior,
frac_obs = obs$scale,
rep_phi = obs$phi,
lower_bounds = c(
alpha = 0,
R0 = 0,
frac_obs = 0,
rep_phi = 0
)
)
)

# rescale mean shifted prior for back calculation if observation scaling is
# used
if (stan_data$obs_scale == 1) {
stan_data$shifted_cases <-
stan_data$shifted_cases / stan_data$obs_scale_mean
stan_data$prior_infections <- log(
exp(stan_data$prior_infections) / stan_data$obs_scale_mean
)
}
stan_data$shifted_cases <-
stan_data$shifted_cases / mean(obs$scale)
stan_data$prior_infections <- log(
exp(stan_data$prior_infections) / mean(obs$scale)
)
return(stan_data)
}

@@ -647,34 +654,15 @@ create_initial_conditions <- function(data) {
out$rescaled_rho < data$ls_min, data$ls_min + 0.001,
default = out$rescaled_rho
))

out$alpha <- array(
truncnorm::rtruncnorm(
1, a = 0, mean = data$alpha_mean, sd = data$alpha_sd
)
)
} else {
out$eta <- array(numeric(0))
out$rescaled_rho <- array(numeric(0))
out$alpha <- array(numeric(0))
}
if (data$model_type == 1) {
out$rep_phi <- array(
truncnorm::rtruncnorm(
1,
a = 0, mean = data$phi_mean, sd = data$phi_sd
)
)
}
if (data$estimate_r == 1) {
out$initial_infections <- array(rnorm(1, data$prior_infections, 0.2))
if (data$seeding_time > 1) {
out$initial_growth <- array(rnorm(1, data$prior_growth, 0.02))
}
out$log_R <- array(rnorm(
n = 1, mean = convert_to_logmean(data$r_mean, data$r_sd),
sd = convert_to_logsd(data$r_mean, data$r_sd)
))
}

if (data$bp_n > 0) {
@@ -684,20 +672,17 @@ create_initial_conditions <- function(data) {
out$bp_sd <- array(numeric(0))
out$bp_effects <- array(numeric(0))
}
if (data$obs_scale_sd > 0) {
out$frac_obs <- array(truncnorm::rtruncnorm(1,
a = 0, b = 1,
mean = data$obs_scale_mean,
sd = data$obs_scale_sd
))
} else {
out$frac_obs <- array(numeric(0))
}
if (data$week_effect > 0) {
out$day_of_week_simplex <- array(
rep(1 / data$week_effect, data$week_effect)
)
}
out$params <- array(truncnorm::rtruncnorm(
data$n_params_variable,
a = data$params_lower,
b = data$params_upper,
mean = 0, sd = 1
seabbs marked this conversation as resolved.
Show resolved Hide resolved
))
return(out)
}
return(init_fun)
@@ -877,3 +862,94 @@ create_stan_delays <- function(..., time_points = 1L) {

return(ret)
}

##' Create parameters for stan
##'
##' @param ... Named delay distributions. The names are assigned to IDs
##' @param lower_bounds Named vector of lower bounds for any delay(s). The names
##' have to correspond to the names given to the delay distributions passed.
##' If `NULL` (default) no parameters are given a lower bound.
##' @return A list of variables as expected by the stan model
##' @importFrom data.table fcase
##' @keywords internal
create_stan_params <- function(..., lower_bounds = NULL) {
params <- list(...)

## set IDs of any parameters that is NULL to 0 and remove
null_params <- vapply(params, is.null, logical(1))
null_ids <- rep(0, sum(null_params))
if (length(null_ids) > 0) {
names(null_ids) <- paste(names(null_params)[null_params], "id", sep = "_")
params <- params[!null_params]
}

## initialise variables
params_fixed_lookup <- rep(0L, length(params))
params_variable_lookup <- rep(0L, length(params))

## identify fixed/variable parameters
fixed <- vapply(params, get_distribution, character(1)) == "fixed"
params_fixed_lookup[fixed] <- seq_along(which(fixed))
params_variable_lookup[!fixed] <- seq_along(which(!fixed))

## lower bounds
params_lower <- rep(-Inf, length(params[!fixed]))
names(params_lower) <- names(params[!fixed])
lower_bounds <- lower_bounds[names(params_lower)]
params_lower[names(lower_bounds)] <- lower_bounds

## upper bounds
params_upper <- vapply(params[!fixed], max, numeric(1))

## prior distributions
prior_dist_name <- vapply(params[!fixed], get_distribution, character(1))
prior_dist <- fcase(
prior_dist_name == "lognormal", 0L,
prior_dist_name == "gamma", 1L,
prior_dist_name == "normal", 2L
)
## parameters
prior_dist_params <- lapply(params[!fixed], get_parameters)
prior_dist_params_lengths <- lengths(prior_dist_params)

## check none of the parameters are uncertain
prior_uncertain <- vapply(prior_dist_params, function(x) {
!all(vapply(x, is.numeric, logical(1)))
}, logical(1))
if (any(prior_uncertain)) {
uncertain_priors <- names(params[!fixed])[prior_uncertain] # nolint: object_usage_linter
cli_abort(
c(
"!" = "Parameter prior distribution{?s} for {.var {uncertain_priors}}
cannot have uncertain parameters."
)
)
}

prior_dist_params <- unlist(prior_dist_params)
if (is.null(prior_dist_params)) {
prior_dist_params <- numeric(0)
}

## extract distributions and parameters
ret <- list(
n_params_variable = length(params) - sum(fixed),
n_params_fixed = sum(fixed),
params_lower = array(params_lower),
params_upper = array(params_upper),
params_fixed_lookup = array(params_fixed_lookup),
params_variable_lookup = array(params_variable_lookup),
params_value = array(vapply(
params[fixed], \(x) get_parameters(x)$value, numeric(1)
)),
prior_dist = array(prior_dist),
prior_dist_params_length = sum(prior_dist_params_lengths),
prior_dist_params = array(prior_dist_params)
)
ids <- seq_along(params)
if (length(ids) > 0) {
names(ids) <- paste(names(params), "id", sep = "_")
}
ret <- c(ret, as.list(ids), as.list(null_ids))
return(ret)
}
51 changes: 51 additions & 0 deletions R/dist_spec.R
Original file line number Diff line number Diff line change
@@ -125,6 +125,57 @@ discrete_pmf <- function(distribution =
c(e1, e2)
}

##' Compares two delay distributions
##'
##' @param e1 The first delay distribution (of type <dist_spec>) to
##' combine.
##'
##' @param e2 The second delay distribution (of type <dist_spec>) to
##' combine.
##' @method == dist_spec
##' @return TRUE or FALSE
##' @export
##' @examples
##' Fixed(1) == Normal(1, 0.5)
## nolint start: cyclocomp_linter
`==.dist_spec` <- function(e1, e2) {
## both must have same number of distributions
if (ndist(e1) != ndist(e2)) return(FALSE)
## loop over constituent distributions
for (i in seq_len(ndist(e1))) {
## distributions need to be the same
if (get_distribution(e1, i) != get_distribution(e2, i)) return(FALSE)
if (get_distribution(e1, i) == "nonparametric") {
## if nonparametric then PMFs need to be the same
if (!identical(get_pmf(e1, i), get_pmf(e2, i))) return(FALSE)
} else {
## if parametric then all parameters need to be the same
params1 <- get_parameters(e1, i)
params2 <- get_parameters(e2, i)
for (param in names(params1)) {
## all parameters must be the same type
if ((is(params1[[param]], "dist_spec") &&
is(params2[[param]], "dist_spec")) ||
(is.numeric(params1[[param]]) && is.numeric(params2[[param]]))) {
## if parameters are the same type they need to be same value
if (!(params1[[param]] == params2[[param]])) return(FALSE)
} else {
return(FALSE)
}
}
}
}
return(TRUE)
}
## nolint end: cyclocomp_linter

##' @rdname equals-.dist_spec
##' @method != dist_spec
##' @export
`!=.dist_spec` <- function(e1, e2) {
!(e1 == e2)
}

#' Combines multiple delay distributions for further processing
#'
#' @description `r lifecycle::badge("experimental")`
2 changes: 1 addition & 1 deletion R/epinow.R
Original file line number Diff line number Diff line change
@@ -65,7 +65,7 @@
#' out <- epinow(
#' data = reported_cases,
#' generation_time = gt_opts(generation_time),
#' rt = rt_opts(prior = list(mean = 2, sd = 0.1)),
#' rt = rt_opts(prior = Normal(mean = 2, sd = 0.1)),
#' delays = delay_opts(incubation_period + reporting_delay)
#' )
#' # summary of the latest estimates
2 changes: 1 addition & 1 deletion R/estimate_infections.R
Original file line number Diff line number Diff line change
@@ -107,7 +107,7 @@
#' def <- estimate_infections(reported_cases,
#' generation_time = gt_opts(generation_time),
#' delays = delay_opts(incubation_period + reporting_delay),
#' rt = rt_opts(prior = list(mean = 2, sd = 0.1))
#' rt = rt_opts(prior = Normal(mean = 2, sd = 0.1))
#' )
#' # real time estimates
#' summary(def)
15 changes: 12 additions & 3 deletions R/estimate_secondary.R
Original file line number Diff line number Diff line change
@@ -101,7 +101,7 @@
#' # fit model to example data specifying a weak prior for fraction reported
#' # with a secondary case
#' inc <- estimate_secondary(cases[1:60],
#' obs = obs_opts(scale = list(mean = 0.2, sd = 0.2), week_effect = FALSE)
#' obs = obs_opts(scale = Normal(mean = 0.2, sd = 0.2), week_effect = FALSE)
#' )
#' plot(inc, primary = TRUE)
#'
@@ -129,7 +129,7 @@
#' secondary = secondary_opts(type = "prevalence"),
#' obs = obs_opts(
#' week_effect = FALSE,
#' scale = list(mean = 0.4, sd = 0.1)
#' scale = Normal(mean = 0.4, sd = 0.1)
#' )
#' )
#' plot(prev, primary = TRUE)
@@ -250,6 +250,15 @@ estimate_secondary <- function(data,
# observation model data
stan_data <- c(stan_data, create_obs_model(obs, dates = reports$date))

stan_data <- c(stan_data, create_stan_params(
frac_obs = obs$scale,
rep_phi = obs$phi,
lower_bounds = c(
frac_obs = 0,
rep_phi = 0
)
))

# update data to use specified priors rather than defaults
stan_data <- update_secondary_args(stan_data,
priors = priors, verbose = verbose
@@ -674,7 +683,7 @@ forecast_secondary <- function(estimate,

# allocate empty parameters
data <- allocate_empty(
data, c("frac_obs", "delay_params", "rep_phi"),
data, c("params", "delay_params"),
n = data$n
)
data$all_dates <- as.integer(all_dates)
Loading