Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change all priors to use <dist_spec> #871

Merged
merged 15 commits into from
Dec 6, 2024
1 change: 1 addition & 0 deletions .github/workflows/lint-only-changed-files.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ jobs:
any::gh
any::lintr
any::purrr
progressr
- name: Add lintr options
run: |
Expand Down
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# Generated by roxygen2: do not edit by hand

S3method("!=",dist_spec)
S3method("+",dist_spec)
S3method("==",dist_spec)
S3method(c,dist_spec)
S3method(collapse,dist_spec)
S3method(collapse,multi_dist_spec)
Expand Down
5 changes: 3 additions & 2 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,9 @@
estimate_infections()
```

- A bug was fixed where the initial growth was never estimated (i.e. the prior mean was always zero). By @sbfnk in #853 and reviewed by @seabbs.
- A bug was fixed where an internal function for applying a default cdf cutoff failed due to a difference a vector length issue. By @jamesmbaazam in #858 and reviewed by @sbfnk.
- A bug was fixed where the initial growth was never estimated (i.e. the prior mean was always zero). By @sbfnk in #853 and reviewed by @seabbs.
- A bug was fixed where an internal function for applying a default cdf cutoff failed due to a difference a vector length issue. By @jamesmbaazam in #858 and reviewed by @sbfnk.
- All parameters have been changed to the new parameter interface. By @sbfnk in # and reviewed by @.
sbfnk marked this conversation as resolved.
Show resolved Hide resolved

## Documentation

Expand Down
166 changes: 121 additions & 45 deletions R/create.R
Original file line number Diff line number Diff line change
Expand Up @@ -319,8 +319,6 @@ create_rt_data <- function(rt = rt_opts(), breakpoints = NULL,

# map settings to underlying gp stan requirements
rt_data <- list(
r_mean = rt$prior$mean,
r_sd = rt$prior$sd,
estimate_r = as.numeric(rt$use_rt),
bp_n = ifelse(rt$use_breakpoints, max(breakpoints) - 1, 0),
breakpoints = breakpoints,
Expand Down Expand Up @@ -433,8 +431,6 @@ create_gp_data <- function(gp = gp_opts(), data) {
ls_sdlog = convert_to_logsd(gp$ls_mean, gp$ls_sd),
ls_min = gp$ls_min,
ls_max = gp$ls_max,
alpha_mean = gp$alpha_mean,
alpha_sd = gp$alpha_sd,
gp_type = data.table::fcase(
gp$kernel == "se", 0,
gp$kernel == "periodic", 1,
Expand Down Expand Up @@ -472,7 +468,7 @@ create_gp_data <- function(gp = gp_opts(), data) {
#'
#' # Applying a observation scaling to the data
#' create_obs_model(
#' obs_opts(scale = list(mean = 0.4, sd = 0.01)), dates = dates
#' obs_opts(scale = Normal(mean = 0.4, sd = 0.01)), dates = dates
#' )
#'
#' # Apply a custom week week length
Expand All @@ -481,13 +477,9 @@ create_gp_data <- function(gp = gp_opts(), data) {
create_obs_model <- function(obs = obs_opts(), dates) {
data <- list(
model_type = as.numeric(obs$family == "negbin"),
phi_mean = obs$phi$mean,
phi_sd = obs$phi$sd,
week_effect = ifelse(obs$week_effect, obs$week_length, 1),
obs_weight = obs$weight,
obs_scale = as.integer(obs$scale$sd > 0 || obs$scale$mean != 1),
obs_scale_mean = obs$scale$mean,
obs_scale_sd = obs$scale$sd,
obs_scale = as.integer(obs$scale != Fixed(1)),
likelihood = as.numeric(obs$likelihood),
return_likelihood = as.numeric(obs$return_likelihood)
)
Expand Down Expand Up @@ -589,15 +581,30 @@ create_stan_data <- function(data, seeding_time,
)
)

# parameters
stan_data <- c(
stan_data,
create_stan_params(
alpha = gp$alpha,
R0 = rt$prior,
frac_obs = obs$scale,
rep_phi = obs$phi,
lower_bounds = c(
alpha = 0,
R0 = 0,
frac_obs = 0,
rep_phi = 0
)
)
)

# rescale mean shifted prior for back calculation if observation scaling is
# used
if (stan_data$obs_scale == 1) {
stan_data$shifted_cases <-
stan_data$shifted_cases / stan_data$obs_scale_mean
stan_data$prior_infections <- log(
exp(stan_data$prior_infections) / stan_data$obs_scale_mean
)
}
stan_data$shifted_cases <-
stan_data$shifted_cases / mean(obs$scale)
stan_data$prior_infections <- log(
exp(stan_data$prior_infections) / mean(obs$scale)
)
return(stan_data)
}

Expand Down Expand Up @@ -647,34 +654,15 @@ create_initial_conditions <- function(data) {
out$rescaled_rho < data$ls_min, data$ls_min + 0.001,
default = out$rescaled_rho
))

out$alpha <- array(
truncnorm::rtruncnorm(
1, a = 0, mean = data$alpha_mean, sd = data$alpha_sd
)
)
} else {
out$eta <- array(numeric(0))
out$rescaled_rho <- array(numeric(0))
out$alpha <- array(numeric(0))
}
if (data$model_type == 1) {
out$rep_phi <- array(
truncnorm::rtruncnorm(
1,
a = 0, mean = data$phi_mean, sd = data$phi_sd
)
)
}
if (data$estimate_r == 1) {
out$initial_infections <- array(rnorm(1, data$prior_infections, 0.2))
if (data$seeding_time > 1) {
out$initial_growth <- array(rnorm(1, data$prior_growth, 0.02))
}
out$log_R <- array(rnorm(
n = 1, mean = convert_to_logmean(data$r_mean, data$r_sd),
sd = convert_to_logsd(data$r_mean, data$r_sd)
))
}

if (data$bp_n > 0) {
Expand All @@ -684,20 +672,17 @@ create_initial_conditions <- function(data) {
out$bp_sd <- array(numeric(0))
out$bp_effects <- array(numeric(0))
}
if (data$obs_scale_sd > 0) {
out$frac_obs <- array(truncnorm::rtruncnorm(1,
a = 0, b = 1,
mean = data$obs_scale_mean,
sd = data$obs_scale_sd
))
} else {
out$frac_obs <- array(numeric(0))
}
if (data$week_effect > 0) {
out$day_of_week_simplex <- array(
rep(1 / data$week_effect, data$week_effect)
)
}
out$params <- array(truncnorm::rtruncnorm(
data$n_params_variable,
a = data$params_lower,
b = data$params_upper,
mean = 0, sd = 1
seabbs marked this conversation as resolved.
Show resolved Hide resolved
))
return(out)
}
return(init_fun)
Expand Down Expand Up @@ -877,3 +862,94 @@ create_stan_delays <- function(..., time_points = 1L) {

return(ret)
}

##' Create parameters for stan
##'
##' @param ... Named delay distributions. The names are assigned to IDs
##' @param lower_bounds Named vector of lower bounds for any delay(s). The names
##' have to correspond to the names given to the delay distributions passed.
##' If `NULL` (default) no parameters are given a lower bound.
##' @return A list of variables as expected by the stan model
##' @importFrom data.table fcase
##' @keywords internal
create_stan_params <- function(..., lower_bounds = NULL) {
params <- list(...)

## set IDs of any parameters that is NULL to 0 and remove
null_params <- vapply(params, is.null, logical(1))
null_ids <- rep(0, sum(null_params))
if (length(null_ids) > 0) {
names(null_ids) <- paste(names(null_params)[null_params], "id", sep = "_")
params <- params[!null_params]
}

## initialise variables
params_fixed_lookup <- rep(0L, length(params))
params_variable_lookup <- rep(0L, length(params))

## identify fixed/variable parameters
fixed <- vapply(params, get_distribution, character(1)) == "fixed"
params_fixed_lookup[fixed] <- seq_along(which(fixed))
params_variable_lookup[!fixed] <- seq_along(which(!fixed))

## lower bounds
params_lower <- rep(-Inf, length(params[!fixed]))
names(params_lower) <- names(params[!fixed])
lower_bounds <- lower_bounds[names(params_lower)]
params_lower[names(lower_bounds)] <- lower_bounds

## upper bounds
params_upper <- vapply(params[!fixed], max, numeric(1))

## prior distributions
prior_dist_name <- vapply(params[!fixed], get_distribution, character(1))
prior_dist <- fcase(
prior_dist_name == "lognormal", 0L,
prior_dist_name == "gamma", 1L,
prior_dist_name == "normal", 2L
)
## parameters
prior_dist_params <- lapply(params[!fixed], get_parameters)
prior_dist_params_lengths <- lengths(prior_dist_params)

## check none of the parameters are uncertain
prior_uncertain <- vapply(prior_dist_params, function(x) {
!all(vapply(x, is.numeric, logical(1)))
}, logical(1))
if (any(prior_uncertain)) {
uncertain_priors <- names(params[!fixed])[prior_uncertain] # nolint: object_usage_linter
cli_abort(
c(
"!" = "Parameter prior distribution{?s} for {.var {uncertain_priors}}
cannot have uncertain parameters."
)
)
}

prior_dist_params <- unlist(prior_dist_params)
if (is.null(prior_dist_params)) {
prior_dist_params <- numeric(0)
}

## extract distributions and parameters
ret <- list(
n_params_variable = length(params) - sum(fixed),
n_params_fixed = sum(fixed),
params_lower = array(params_lower),
params_upper = array(params_upper),
params_fixed_lookup = array(params_fixed_lookup),
params_variable_lookup = array(params_variable_lookup),
params_value = array(vapply(
params[fixed], \(x) get_parameters(x)$value, numeric(1)
)),
prior_dist = array(prior_dist),
prior_dist_params_length = sum(prior_dist_params_lengths),
prior_dist_params = array(prior_dist_params)
)
ids <- seq_along(params)
if (length(ids) > 0) {
names(ids) <- paste(names(params), "id", sep = "_")
}
ret <- c(ret, as.list(ids), as.list(null_ids))
return(ret)
}
51 changes: 51 additions & 0 deletions R/dist_spec.R
Original file line number Diff line number Diff line change
Expand Up @@ -125,6 +125,57 @@ discrete_pmf <- function(distribution =
c(e1, e2)
}

##' Compares two delay distributions
##'
##' @param e1 The first delay distribution (of type <dist_spec>) to
##' combine.
##'
##' @param e2 The second delay distribution (of type <dist_spec>) to
##' combine.
##' @method == dist_spec
##' @return TRUE or FALSE
##' @export
##' @examples
##' Fixed(1) == Normal(1, 0.5)
## nolint start: cyclocomp_linter
`==.dist_spec` <- function(e1, e2) {
## both must have same number of distributions
if (ndist(e1) != ndist(e2)) return(FALSE)
## loop over constituent distributions
for (i in seq_len(ndist(e1))) {
## distributions need to be the same
if (get_distribution(e1, i) != get_distribution(e2, i)) return(FALSE)
if (get_distribution(e1, i) == "nonparametric") {
## if nonparametric then PMFs need to be the same
if (!identical(get_pmf(e1, i), get_pmf(e2, i))) return(FALSE)
} else {
## if parametric then all parameters need to be the same
params1 <- get_parameters(e1, i)
params2 <- get_parameters(e2, i)
for (param in names(params1)) {
## all parameters must be the same type
if ((is(params1[[param]], "dist_spec") &&
is(params2[[param]], "dist_spec")) ||
(is.numeric(params1[[param]]) && is.numeric(params2[[param]]))) {
## if parameters are the same type they need to be same value
if (!(params1[[param]] == params2[[param]])) return(FALSE)
} else {
return(FALSE)
}
}
}
}
return(TRUE)
}
## nolint end: cyclocomp_linter

##' @rdname equals-.dist_spec
##' @method != dist_spec
##' @export
`!=.dist_spec` <- function(e1, e2) {
!(e1 == e2)
}

#' Combines multiple delay distributions for further processing
#'
#' @description `r lifecycle::badge("experimental")`
Expand Down
2 changes: 1 addition & 1 deletion R/epinow.R
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@
#' out <- epinow(
#' data = reported_cases,
#' generation_time = gt_opts(generation_time),
#' rt = rt_opts(prior = list(mean = 2, sd = 0.1)),
#' rt = rt_opts(prior = Normal(mean = 2, sd = 0.1)),
#' delays = delay_opts(incubation_period + reporting_delay)
#' )
#' # summary of the latest estimates
Expand Down
2 changes: 1 addition & 1 deletion R/estimate_infections.R
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@
#' def <- estimate_infections(reported_cases,
#' generation_time = gt_opts(generation_time),
#' delays = delay_opts(incubation_period + reporting_delay),
#' rt = rt_opts(prior = list(mean = 2, sd = 0.1))
#' rt = rt_opts(prior = Normal(mean = 2, sd = 0.1))
#' )
#' # real time estimates
#' summary(def)
Expand Down
15 changes: 12 additions & 3 deletions R/estimate_secondary.R
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@
#' # fit model to example data specifying a weak prior for fraction reported
#' # with a secondary case
#' inc <- estimate_secondary(cases[1:60],
#' obs = obs_opts(scale = list(mean = 0.2, sd = 0.2), week_effect = FALSE)
#' obs = obs_opts(scale = Normal(mean = 0.2, sd = 0.2), week_effect = FALSE)
#' )
#' plot(inc, primary = TRUE)
#'
Expand Down Expand Up @@ -129,7 +129,7 @@
#' secondary = secondary_opts(type = "prevalence"),
#' obs = obs_opts(
#' week_effect = FALSE,
#' scale = list(mean = 0.4, sd = 0.1)
#' scale = Normal(mean = 0.4, sd = 0.1)
#' )
#' )
#' plot(prev, primary = TRUE)
Expand Down Expand Up @@ -250,6 +250,15 @@ estimate_secondary <- function(data,
# observation model data
stan_data <- c(stan_data, create_obs_model(obs, dates = reports$date))

stan_data <- c(stan_data, create_stan_params(
frac_obs = obs$scale,
rep_phi = obs$phi,
lower_bounds = c(
frac_obs = 0,
rep_phi = 0
)
))

# update data to use specified priors rather than defaults
stan_data <- update_secondary_args(stan_data,
priors = priors, verbose = verbose
Expand Down Expand Up @@ -674,7 +683,7 @@ forecast_secondary <- function(estimate,

# allocate empty parameters
data <- allocate_empty(
data, c("frac_obs", "delay_params", "rep_phi"),
data, c("params", "delay_params"),
n = data$n
)
data$all_dates <- as.integer(all_dates)
Expand Down
Loading
Loading