Skip to content

Commit

Permalink
Merge c606708 into 5b712af
Browse files Browse the repository at this point in the history
  • Loading branch information
sbfnk authored Oct 25, 2023
2 parents 5b712af + c606708 commit 4470973
Show file tree
Hide file tree
Showing 36 changed files with 229 additions and 292 deletions.
3 changes: 1 addition & 2 deletions R/create.R
Original file line number Diff line number Diff line change
Expand Up @@ -654,8 +654,7 @@ create_stan_delays <- function(..., weight = 1) {
## number of different non-empty types
type_n <- unlist(purrr::transpose(dot_args)$n)
## assign ID values to each type
ids <- rep(0L, length(type_n))
ids[type_n > 0] <- seq_len(sum(type_n > 0))
ids <- seq_along(type_n)
names(ids) <- paste(names(type_n), "id", sep = "_")

## start consructing stan object
Expand Down
24 changes: 8 additions & 16 deletions R/dist.R
Original file line number Diff line number Diff line change
Expand Up @@ -978,13 +978,7 @@ dist_spec <- function(mean, sd = 0, mean_sd = 0, sd_sd = 0,
)
if (length(pmf) == 0) {
if (missing(mean)) { ## empty
ret <- c(ret, list(
n = 0,
n_p = 0,
n_np = 0,
np_pmf = numeric(0),
fixed = integer(0)
))
pmf <- 1
} else { ## parametric fixed
if (sd == 0) { ## delta
pmf <- c(rep(0, mean), 1)
Expand Down Expand Up @@ -1013,15 +1007,13 @@ dist_spec <- function(mean, sd = 0, mean_sd = 0, sd_sd = 0,
}
pmf <- pmf / sum(pmf)
}
if (length(pmf) > 0) {
ret <- c(ret, list(
n = 1,
n_p = 0,
n_np = 1,
np_pmf = pmf,
fixed = 1L
))
}
ret <- c(ret, list(
n = 1,
n_p = 0,
n_np = 1,
np_pmf = pmf,
fixed = 1L
))
} else {
ret <- list(
mean_mean = mean,
Expand Down
4 changes: 3 additions & 1 deletion R/estimate_infections.R
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
#' @importFrom lubridate days
#' @importFrom purrr transpose
#' @importFrom futile.logger flog.threshold flog.warn flog.debug
#' @importFrom rstan extract
#' @examples
#' \donttest{
#' # set number of cores to use
Expand Down Expand Up @@ -234,7 +235,8 @@ estimate_infections <- function(reported_cases,
fit <- fit_model_with_vb(args, id = id)
}
# Extract parameters of interest from the fit
out <- extract_parameter_samples(fit, data,
samples <- extract(fit)
out <- extract_parameter_samples(samples, data,
reported_inf_dates = reported_cases$date,
reported_dates = reported_cases$date[-(1:data$seeding_time)]
)
Expand Down
14 changes: 9 additions & 5 deletions R/estimate_secondary.R
Original file line number Diff line number Diff line change
Expand Up @@ -628,6 +628,7 @@ forecast_secondary <- function(estimate,
)
# extract data from stanfit
data <- estimate$data
data$primary <- NULL

# combined primary from data and input primary
primary_fit <- estimate$predictions[,
Expand All @@ -643,19 +644,20 @@ forecast_secondary <- function(estimate,
data.table::setorderv(primary_fit, c("sample", "date"))

# update data with primary samples and day of week
data$primary <- t(
data$primary_samples <- t(
matrix(primary_fit$value, ncol = length(unique(primary_fit$sample)))
)
data$day_of_week <- add_day_of_week(
unique(primary_fit$date), data$week_effect
)
data$n <- nrow(data$primary)
data$t <- ncol(data$primary)
data$n <- nrow(data$primary_samples)
data$t <- ncol(data$primary_samples)
data$h <- nrow(primary[sample == min(sample)])

# extract samples for posterior of estimates
posterior_samples <- sample(seq_len(data$n), data$n, replace = TRUE) # nolint
draws <- purrr::map(draws, ~ as.matrix(.[posterior_samples, ]))
names(draws) <- paste0(names(draws), "_samples")
# combine with data
data <- c(data, draws)

Expand All @@ -666,8 +668,10 @@ forecast_secondary <- function(estimate,

# allocate empty parameters
data <- allocate_empty(
data, c("frac_obs", "delay_mean", "delay_sd", "rep_phi"),
n = data$n
data, c(
"frac_obs_samples", "delay_mean_samples", "delay_sd_samples",
"rep_phi_samples"
), n = data$n
)
data$all_dates <- as.integer(all_dates)
## simulate
Expand Down
9 changes: 3 additions & 6 deletions R/extract.R
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,8 @@ extract_static_parameter <- function(param, samples) {
#' Extracts a custom set of parameters from a stan object and adds
#' stratification and dates where appropriate.
#'
#' @param stan_fit A fit Stan model as returned by `rstan:sampling`.
#' @param samples A list of arrays containing samples extracted from a stan
#' model using \code{\link[rsta]{extract}}
#'
#' @param data A list of the data supplied to the `rstan::sampling` call.
#'
Expand All @@ -77,14 +78,10 @@ extract_static_parameter <- function(param, samples) {
#'
#' @return A list of dataframes each containing the posterior of a parameter
#' @author Sam Abbott
#' @importFrom rstan extract
#' @importFrom data.table data.table
extract_parameter_samples <- function(stan_fit, data, reported_dates,
extract_parameter_samples <- function(samples, data, reported_dates,
reported_inf_dates,
drop_length_1 = FALSE, merge = FALSE) {
# extract sample from stan object
samples <- rstan::extract(stan_fit)

## drop initial length 1 dimensions if requested
if (drop_length_1) {
samples <- lapply(samples, function(x) {
Expand Down
17 changes: 14 additions & 3 deletions R/simulate_infections.R
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
#' @importFrom progressr with_progress progressor
#' @importFrom data.table rbindlist as.data.table
#' @importFrom lubridate days
#' @importFrom rstan extract
#' @return A list of output as returned by [estimate_infections()] but based on
#' results from the specified scenario rather than fitting.
#' @export
Expand Down Expand Up @@ -186,13 +187,17 @@ simulate_infections <- function(estimates,
shift, dates, nstart, nend) {
# extract batch samples from draws
draws <- map(draws, ~ as.matrix(.[nstart:nend, ]))
names(draws) <- paste0(names(draws), "_samples")

## prepare data for stan command
data <- c(list(n = dim(draws$R)[1]), draws, estimates$args)
data <- c(list(n = dim(draws$R_samples)[1]), draws, estimates$args)

## allocate empty parameters
data <- allocate_empty(
data, c("frac_obs", "delay_mean", "delay_sd", "rep_phi"),
data, c(
"frac_obs_samples", "delay_mean_samples", "delay_sd_samples",
"rep_phi_samples"
),
n = data$n
)

Expand All @@ -203,8 +208,12 @@ simulate_infections <- function(estimates,
algorithm = "Fixed_param",
refresh = 0
)
## extract sample from stan object
samples <- extract(sims)
names(samples) <- sub("^sim_", "", names(samples))
names(data) <- sub("_samples$", "", names(data))

out <- extract_parameter_samples(sims, data,
out <- extract_parameter_samples(samples, data,
reported_inf_dates = dates,
reported_dates = dates[-(1:shift)],
drop_length_1 = TRUE, merge = TRUE
Expand Down Expand Up @@ -251,6 +260,8 @@ simulate_infections <- function(estimates,
out <- transpose(out)
out <- map(out, rbindlist)

names(out) <- sub("^sim_", "", names(out))

## format output
format_out <- format_fit(
posterior_samples = out,
Expand Down
3 changes: 3 additions & 0 deletions inst/stan/chunks/R_to_growth.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
real gt_mean = rev_pmf_mean(gt_rev_pmf, 1);
real gt_var = rev_pmf_var(gt_rev_pmf, 1, gt_mean);
r = R_to_growth(R, gt_mean, gt_var);
4 changes: 4 additions & 0 deletions inst/stan/chunks/calculate_secondary.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
reports = calculate_secondary(
primary, obs, frac_obs, delay_rev_pmf, cumulative,
historic, primary_hist_additive, current, primary_current_additive, predict
);
1 change: 1 addition & 0 deletions inst/stan/chunks/convolve_to_report.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
reports = convolve_to_report(infections, delay_rev_pmf, seeding_time);
3 changes: 3 additions & 0 deletions inst/stan/chunks/day_of_week_effect.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
if (week_effect > 1) {
reports = day_of_week_effect(reports, day_of_week, day_of_week_simplex);
}
6 changes: 6 additions & 0 deletions inst/stan/chunks/delay_rev_pmf.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
vector[delay_type_max[delay_id] + 1] delay_rev_pmf = get_delay_rev_pmf(
delay_id, delay_type_max[delay_id] + 1, delay_types_p, delay_types_id,
delay_types_groups, delay_max, delay_np_pmf,
delay_np_pmf_groups, delay_mean, delay_sd, delay_dist,
0, 1, 0
);
12 changes: 12 additions & 0 deletions inst/stan/chunks/delay_type_max.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
array[delay_types] int delay_type_max;
for (i in 1:delay_types) {
delay_type_max[i] = 0;
for (j in delay_types_groups[i]:(delay_types_groups[i + 1] - 1)) {
if (delay_types_p[j]) { // parametric
delay_type_max[i] += delay_max[delay_types_id[j]];
} else { // nonparametric
delay_type_max[i] += delay_np_pmf_groups[delay_types_id[j] + 1] -
delay_np_pmf_groups[delay_types_id[j]] - 1;
}
}
}
4 changes: 4 additions & 0 deletions inst/stan/chunks/delays_lp.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
delays_lp(
delay_mean, delay_mean_mean, delay_mean_sd, delay_sd, delay_sd_mean,
delay_sd_sd, delay_dist, delay_weight
);
4 changes: 4 additions & 0 deletions inst/stan/chunks/generate_infections.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
infections = generate_infections(
R, seeding_time, gt_rev_pmf, initial_infections, initial_growth, pop,
future_time
);
6 changes: 6 additions & 0 deletions inst/stan/chunks/gt_rev_pmf.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
vector[delay_type_max[gt_id] + 1] gt_rev_pmf = get_delay_rev_pmf(
gt_id, delay_type_max[gt_id] + 1, delay_types_p, delay_types_id,
delay_types_groups, delay_max, delay_np_pmf,
delay_np_pmf_groups, delay_mean, delay_sd, delay_dist,
1, 1, 0
);
1 change: 1 addition & 0 deletions inst/stan/chunks/impute_reports.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
imputed_reports = report_rng(reports, rep_phi, model_type);
3 changes: 3 additions & 0 deletions inst/stan/chunks/obs_scale_lp.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
if (obs_scale) {
frac_obs[1] ~ normal(obs_scale_mean, obs_scale_sd) T[0, 1];
}
3 changes: 3 additions & 0 deletions inst/stan/chunks/report_log_lik.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
log_lik = report_log_lik(
cases, obs_reports, rep_phi, model_type, obs_weight
);
3 changes: 3 additions & 0 deletions inst/stan/chunks/report_lp.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
report_lp(
cases, obs_reports, rep_phi, phi_mean, phi_sd, model_type, obs_weight
);
3 changes: 3 additions & 0 deletions inst/stan/chunks/scale_obs.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
if (obs_scale) {
reports = scale_obs(reports, frac_obs[1]);
}
7 changes: 7 additions & 0 deletions inst/stan/chunks/sim_vars.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
array[delay_n_p] real delay_mean = delay_mean_samples[i];
array[delay_n_p] real delay_sd = delay_sd_samples[i];
array[obs_scale] real frac_obs = frac_obs_samples[i];
array[model_type] real rep_phi = rep_phi_samples[i];
vector[week_effect] day_of_week_simplex = to_vector(
day_of_week_simplex_samples[i]
);
6 changes: 6 additions & 0 deletions inst/stan/chunks/trunc_rev_cmf.stan
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
vector[delay_type_max[trunc_id] + 1] trunc_rev_cmf = get_delay_rev_pmf(
trunc_id, delay_type_max[trunc_id] + 1, delay_types_p, delay_types_id,
delay_types_groups, delay_max, delay_np_pmf,
delay_np_pmf_groups, delay_mean, delay_sd, delay_dist,
0, 1, 1
);
4 changes: 2 additions & 2 deletions inst/stan/data/simulation_delays.stan
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
int<lower = 0> delay_n; // number of delay distribution distributions
int<lower = 0> delay_n_p; // number of parametric delay distributions
int<lower = 0> delay_n_np; // number of nonparametric delay distributions
array[n, delay_n_p] real delay_mean; // prior mean of mean delay distribution
array[n, delay_n_p] real<lower = 0> delay_sd; // prior sd of sd of delay distribution
array[n, delay_n_p] real delay_mean_samples; // prior mean of mean delay distribution
array[n, delay_n_p] real<lower = 0> delay_sd_samples; // prior sd of sd of delay distribution
array[delay_n_p] int<lower = 1> delay_max; // maximum delay distribution
array[delay_n_p] int<lower = 0> delay_dist; // 0 = lognormal; 1 = gamma
int<lower = 0> delay_np_pmf_length; // number of nonparametric pmf elements
Expand Down
6 changes: 3 additions & 3 deletions inst/stan/data/simulation_observation_model.stan
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
array[t - seeding_time] int day_of_week; // day of the week indicator (1 - 7)
int week_effect; // should a day of the week effect be estimated
array[n, week_effect] real<lower = 0> day_of_week_simplex;
array[n, week_effect] real<lower = 0> day_of_week_simplex_samples;
int obs_scale;
array[n, obs_scale] real<lower = 0, upper = 1> frac_obs;
array[n, obs_scale] real<lower = 0, upper = 1> frac_obs_samples;
int model_type;
array[n, model_type] real<lower = 0> rep_phi; // overdispersion of the reporting process
array[n, model_type] real<lower = 0> rep_phi_samples; // overdispersion of the reporting process
int<lower = 0> trunc_id; // id of truncation
10 changes: 5 additions & 5 deletions inst/stan/data/simulation_rt.stan
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
array[seeding_time ? n : 0, 1] real initial_infections; // initial logged infections
array[seeding_time > 1 ? n : 0, 1] real initial_growth; //initial growth
array[seeding_time ? n : 0, 1] real initial_infections_samples; // initial logged infections
array[seeding_time > 1 ? n : 0, 1] real initial_growth_samples; //initial growth

matrix[n, t - seeding_time] R; // reproduction number
int pop; // susceptible population
array[n, t - seeding_time] real R_samples; // reproduction number
int pop; // susceptible population

int<lower = 0> gt_id; // id of generation time
int<lower = 0> gt_id; // id of generation time
Loading

0 comments on commit 4470973

Please sign in to comment.