diff --git a/R/create.R b/R/create.R index a82316a36..35745b075 100644 --- a/R/create.R +++ b/R/create.R @@ -342,7 +342,7 @@ create_gp_data <- function(gp = gp_opts(), data) { #' create_obs_model(obs_opts(week_length = 3), dates = dates) create_obs_model <- function(obs = obs_opts(), dates) { data <- list( - model_type = ifelse(obs$family %in% "poisson", 0, 1), + obs_dist = ifelse(obs$family %in% "poisson", 0, 1), phi_mean = obs$phi[1], phi_sd = obs$phi[2], week_effect = ifelse(obs$week_effect, obs$week_length, 1), @@ -386,7 +386,7 @@ create_obs_model <- function(obs = obs_opts(), dates) { create_stan_data <- function(reported_cases, generation_time, rt, gp, obs, delays, horizon, backcalc, shifted_cases, - truncation) { + truncation, process_model) { ## make sure we have at least max_gt seeding time delays$seeding_time <- max(delays$seeding_time, generation_time$max) @@ -429,7 +429,8 @@ create_stan_data <- function(reported_cases, generation_time, gt_sd_mean = generation_time$sd, gt_sd_sd = generation_time$sd_sd, max_gt = generation_time$max, - burn_in = 0 + burn_in = 0, + process_model = process_model ) # add delay data data <- c(data, delays) @@ -530,7 +531,7 @@ create_initial_conditions <- function(data) { ) out$alpha <- array(truncnorm::rtruncnorm(1, a = 0, mean = 0, sd = data$alpha_sd)) } - if (data$model_type == 1) { + if (data$obs_dist == 1) { out$rep_phi <- array( truncnorm::rtruncnorm( 1, a = 0, mean = data$phi_mean, sd = data$phi_sd / 10 @@ -542,10 +543,10 @@ create_initial_conditions <- function(data) { if (data$seeding_time > 1) { out$initial_growth <- array(rnorm(1, data$prior_growth, 0.01)) } - out$log_R <- array(rnorm( + out$base_cov <- rnorm( n = 1, mean = convert_to_logmean(data$r_mean, data$r_sd), sd = convert_to_logsd(data$r_mean, data$r_sd) * 0.1 - )) + ) if (data$gt_mean_sd > 0) { out$gt_mean <- array(truncnorm::rtruncnorm(1, a = 0, mean = data$gt_mean_mean, diff --git a/R/estimate_infections.R b/R/estimate_infections.R index 60212df25..3dc27cc45 100644 --- a/R/estimate_infections.R +++ b/R/estimate_infections.R @@ -187,6 +187,7 @@ #' } estimate_infections <- function(reported_cases, generation_time, + model = "R", delays = delay_opts(), truncation = trunc_opts(), rt = rt_opts(), @@ -247,6 +248,10 @@ estimate_infections <- function(reported_cases, ) reported_cases <- reported_cases[-(1:backcalc$prior_window)] + model_choices <- c("infections", "growth", "R") + model <- match.arg(model_choices, choices = model_choices) + process_model <- which(model == model_choices) - 1 + # Define stan model parameters data <- create_stan_data( reported_cases = reported_cases, @@ -258,7 +263,8 @@ estimate_infections <- function(reported_cases, obs = obs, backcalc = backcalc, shifted_cases = shifted_cases$confirm, - horizon = horizon + horizon = horizon, + process_model = process_model ) # Set up default settings diff --git a/R/extract.R b/R/extract.R index 3f1820e24..4201ef393 100644 --- a/R/extract.R +++ b/R/extract.R @@ -154,7 +154,7 @@ extract_parameter_samples <- function(stan_fit, data, reported_dates, reported_i out$gt_sd <- extract_static_parameter("gt_sd", samples) out$gt_sd <- out$gt_sd[, value := value.V1][, value.V1 := NULL] } - if (data$model_type == 1) { + if (data$obs_dist == 1) { out$reporting_overdispersion <- extract_static_parameter("rep_phi", samples) out$reporting_overdispersion <- out$reporting_overdispersion[, value := value.V1][ , diff --git a/inst/stan/data/covariates.stan b/inst/stan/data/covariates.stan new file mode 100644 index 000000000..d96427e79 --- /dev/null +++ b/inst/stan/data/covariates.stan @@ -0,0 +1,3 @@ +int process_model; // 0 = infections; 1 = growth; 2 = rt +int bp_n; // no of breakpoints (0 = no breakpoints) +int breakpoints[t - seeding_time]; // when do breakpoints occur diff --git a/inst/stan/data/observation_model.stan b/inst/stan/data/observation_model.stan index e6350dfd3..bb527090c 100644 --- a/inst/stan/data/observation_model.stan +++ b/inst/stan/data/observation_model.stan @@ -1,5 +1,5 @@ int day_of_week[t - seeding_time]; // day of the week indicator (1 - 7) - int model_type; // type of model: 0 = poisson otherwise negative binomial + int obs_dist; // type of model: 0 = poisson otherwise negative binomial real phi_mean; // Mean and sd of the normal prior for the real phi_sd; // reporting process int week_effect; // length of week effect diff --git a/inst/stan/data/observations.stan b/inst/stan/data/observations.stan index 845a6751a..7c319fac5 100644 --- a/inst/stan/data/observations.stan +++ b/inst/stan/data/observations.stan @@ -1,5 +1,5 @@ int t; // unobserved time - int seeding_time; // time period used for seeding and not observed + int seeding_time; // time period used for seeding and not observed int horizon; // forecast horizon int future_time; // time in future for Rt int cases[t - horizon - seeding_time]; // observed cases diff --git a/inst/stan/data/rt.stan b/inst/stan/data/rt.stan index 9914b6872..834bd2b5d 100644 --- a/inst/stan/data/rt.stan +++ b/inst/stan/data/rt.stan @@ -3,8 +3,6 @@ real prior_growth; // prior on initial growth rate real r_mean; // prior mean of reproduction number real r_sd; // prior standard deviation of reproduction number - int bp_n; // no of breakpoints (0 = no breakpoints) - int breakpoints[t - seeding_time]; // when do breakpoints occur int future_fixed; // is underlying future Rt assumed to be fixed int fixed_from; // Reference date for when Rt estimation should be fixed int pop; // Initial susceptible population diff --git a/inst/stan/data/simulation_observation_model.stan b/inst/stan/data/simulation_observation_model.stan index 4554ad4b4..da02c0727 100644 --- a/inst/stan/data/simulation_observation_model.stan +++ b/inst/stan/data/simulation_observation_model.stan @@ -3,5 +3,5 @@ real day_of_week_simplex[n, week_effect]; int obs_scale; real frac_obs[n, obs_scale]; - int model_type; - real rep_phi[n, model_type]; // overdispersion of the reporting process + int obs_dist; + real rep_phi[n, obs_dist]; // overdispersion of the reporting process diff --git a/inst/stan/estimate_infections.stan b/inst/stan/estimate_infections.stan index 79700631f..668a1d99b 100644 --- a/inst/stan/estimate_infections.stan +++ b/inst/stan/estimate_infections.stan @@ -2,7 +2,7 @@ functions { #include functions/pmfs.stan #include functions/convolve.stan #include functions/gaussian_process.stan -#include functions/rt.stan +#include functions/covariates.stan #include functions/infections.stan #include functions/observation_model.stan #include functions/generated_quantities.stan @@ -12,6 +12,7 @@ functions { data { #include data/observations.stan #include data/delays.stan +#include data/covariates.stan #include data/gaussian_process.stan #include data/generation_time.stan #include data/rt.stan @@ -36,8 +37,7 @@ parameters{ real rho[fixed ? 0 : 1]; // length scale of noise GP real alpha[fixed ? 0 : 1]; // scale of of noise GP vector[fixed ? 0 : M] eta; // unconstrained noise - // Rt - vector[estimate_r] log_R; // baseline reproduction number estimate (log) + real base_cov; // covariate (R/r/inf) real initial_infections[estimate_r] ; // seed infections real initial_growth[estimate_r && seeding_time > 1 ? 1 : 0]; // seed growth rate real gt_mean[estimate_r && gt_mean_sd > 0]; // mean of generation time (if uncertain) @@ -51,31 +51,36 @@ parameters{ real frac_obs[obs_scale]; // fraction of cases that are ultimately observed real truncation_mean[truncation]; // mean of truncation real truncation_sd[truncation]; // sd of truncation - real rep_phi[model_type]; // overdispersion of the reporting process + real rep_phi[obs_dist]; // overdispersion of the reporting process } transformed parameters { vector[fixed ? 0 : noise_terms] noise; // noise generated by the gaussian process - vector[estimate_r > 0 ? ot_h : 0] R; // reproduction number + vector[seeding_time] uobs_inf; vector[t] infections; // latent infections + vector[ot_h] cov; // covaraites vector[ot_h] reports; // estimated reported cases vector[ot] obs_reports; // observed estimated reported cases // GP in noise - spectral densities if (!fixed) { noise = update_gp(PHI, M, L, alpha[1], rho[1], eta, gp_type); } + // update covariates + cov = update_covariate(base_cov, noise, breakpoints, bp_effects, stationary, ot_h, 0); + uobs_inf = generate_seed(initial_infections, initial_growth, seeding_time); // Estimate latent infections - if (estimate_r) { + if (process_model == 0) { + // via deconvolution + infections = infection_model(cov, uobs_inf, future_time); + } else if (process_model == 1) { + // via growth + infections = growth_model(cov, uobs_inf, future_time); + } else if (process_model == 2) { // via Rt real set_gt_mean = (gt_mean_sd > 0 ? gt_mean[1] : gt_mean_mean); real set_gt_sd = (gt_sd_sd > 0 ? gt_sd[1] : gt_sd_mean); - R = update_Rt(R, log_R[estimate_r], noise, breakpoints, bp_effects, stationary); - infections = generate_infections(R, seeding_time, set_gt_mean, set_gt_sd, max_gt, - initial_infections, initial_growth, - pop, future_time); - } else { - // via deconvolution - infections = deconvolve_infections(shifted_cases, noise, fixed, backcalc_prior); + infections = renewal_model(cov, uobs_inf, set_gt_mean, set_gt_sd, max_gt, + pop, future_time); } // convolve from latent infections to mean of observations reports = convolve_to_report(infections, delay_mean, delay_sd, max_delay, seeding_time); @@ -98,18 +103,23 @@ model { gaussian_process_lp(rho[1], alpha[1], eta, ls_meanlog, ls_sdlog, ls_min, ls_max, alpha_sd); } - // penalised priors for delay distributions - delays_lp(delay_mean, delay_mean_mean, delay_mean_sd, delay_sd, delay_sd_mean, delay_sd_sd, t); + if (delays > 0) { + // penalised priors for delay distributions + delays_lp(delay_mean, delay_mean_mean, delay_mean_sd, delay_sd, delay_sd_mean, delay_sd_sd, t); + } // priors for truncation - truncation_lp(truncation_mean, truncation_sd, trunc_mean_mean, trunc_mean_sd, - trunc_sd_mean, trunc_sd_sd); + if (truncation) { + truncation_lp(truncation_mean, truncation_sd, trunc_mean_mean, trunc_mean_sd, + trunc_sd_mean, trunc_sd_sd); + } if (estimate_r) { - // priors on Rt - rt_lp(log_R, initial_infections, initial_growth, bp_effects, bp_sd, bp_n, seeding_time, - r_logmean, r_logsd, prior_infections, prior_growth); // penalised_prior on generation interval generation_time_lp(gt_mean, gt_mean_mean, gt_mean_sd, gt_sd, gt_sd_mean, gt_sd_sd, ot); } + // priors on Rt + covariate_lp(base_cov, bp_effects, bp_sd, bp_n, r_logmean, r_logsd); + infections_lp(initial_infections, initial_growth, prior_infections, prior_growth, + seeding_time); // prior observation scaling if (obs_scale) { frac_obs[1] ~ normal(obs_scale_mean, obs_scale_sd) T[0,]; @@ -117,13 +127,13 @@ model { // observed reports from mean of reports (update likelihood) if (likelihood) { report_lp(cases, obs_reports, rep_phi, phi_mean, phi_sd, - model_type, obs_weight); + obs_dist, obs_weight); } } generated quantities { int imputed_reports[ot_h]; - vector[estimate_r > 0 ? 0: ot_h] gen_R; + vector[estimate_r > 0 ? 0: ot_h] R; real r[ot_h - 1]; vector[return_likelihood > 1 ? ot : 0] log_lik; if (estimate_r == 0){ @@ -131,16 +141,16 @@ generated quantities { real gt_mean_sample = (gt_mean_sd > 0 ? normal_rng(gt_mean_mean, gt_mean_sd) : gt_mean_mean); real gt_sd_sample = (gt_sd_sd > 0 ? normal_rng(gt_sd_mean, gt_sd_sd) : gt_sd_mean); // calculate Rt using infections and generation time - gen_R = calculate_Rt(infections, seeding_time, gt_mean_sample, gt_sd_sample, + R = calculate_Rt(infections, seeding_time, gt_mean_sample, gt_sd_sample, max_gt, rt_half_window); } // estimate growth from infections r = calculate_growth(infections, seeding_time + 1); // simulate reported cases - imputed_reports = report_rng(reports, rep_phi, model_type); + imputed_reports = report_rng(reports, rep_phi, obs_dist); // log likelihood of model if (return_likelihood) { - log_lik = report_log_lik(cases, obs_reports, rep_phi, model_type, + log_lik = report_log_lik(cases, obs_reports, rep_phi, obs_dist, obs_weight); } } diff --git a/inst/stan/estimate_secondary.stan b/inst/stan/estimate_secondary.stan index 7c4dbdab4..397cdd2d6 100644 --- a/inst/stan/estimate_secondary.stan +++ b/inst/stan/estimate_secondary.stan @@ -23,7 +23,7 @@ parameters{ real frac_obs[obs_scale]; // fraction of cases that are ultimately observed real truncation_mean[truncation]; // mean of truncation real truncation_sd[truncation]; // sd of truncation - real rep_phi[model_type]; // overdispersion of the reporting process + real rep_phi[obs_dist]; // overdispersion of the reporting process } transformed parameters { @@ -54,7 +54,7 @@ model { // observed secondary reports from mean of secondary reports (update likelihood) if (likelihood) { report_lp(obs[(burn_in + 1):t], secondary[(burn_in + 1):t], - rep_phi, phi_mean, phi_sd, model_type, 1); + rep_phi, phi_mean, phi_sd, obs_dist, 1); } } @@ -62,10 +62,10 @@ generated quantities { int sim_secondary[t - burn_in]; vector[return_likelihood > 1 ? t - burn_in : 0] log_lik; // simulate secondary reports - sim_secondary = report_rng(secondary[(burn_in + 1):t], rep_phi, model_type); + sim_secondary = report_rng(secondary[(burn_in + 1):t], rep_phi, obs_dist); // log likelihood of model if (return_likelihood) { log_lik = report_log_lik(obs[(burn_in + 1):t], secondary[(burn_in + 1):t], - rep_phi, model_type, obs_weight); + rep_phi, obs_dist, obs_weight); } } diff --git a/inst/stan/functions/covariates.stan b/inst/stan/functions/covariates.stan index 7e1986ec3..bfaaccf03 100644 --- a/inst/stan/functions/covariates.stan +++ b/inst/stan/functions/covariates.stan @@ -40,21 +40,15 @@ vector update_covariate(real base_cov, vector noise, int[] bps, } return(cov); } -// Rt priors -void covariate_lp(vector base_cov, real[] initial_infections, real[] initial_growth, - real[] bp_effects, real[] bp_sd, int bp_n, int seeding_time, - real r_logmean, real r_logsd, real prior_infections, - real prior_growth) { + +void covariate_lp(real base_cov, + real[] bp_effects, real[] bp_sd, int bp_n, + real r_logmean, real r_logsd) { // initial prior - base_cov ~ normal(base_logmean, base_logsd); + base_cov ~ normal(r_logmean, r_logsd); //breakpoint effects on Rt if (bp_n > 0) { bp_sd[1] ~ normal(0, 0.1) T[0,]; bp_effects ~ normal(0, bp_sd[1]); } - // initial infections - initial_infections ~ normal(prior_infections, 0.2); - if (seeding_time > 1) { - initial_growth ~ normal(prior_growth, 0.2); - } } diff --git a/inst/stan/functions/infections.stan b/inst/stan/functions/infections.stan index 2ebb8225f..1863530ad 100644 --- a/inst/stan/functions/infections.stan +++ b/inst/stan/functions/infections.stan @@ -23,15 +23,15 @@ vector generate_seed(real[] initial_infections, real[] initial_growth, int uot) seed_infs[s] = exp(initial_infections[1] + initial_growth[1] * (s - 1)); } } - return(seed_infs) + return(seed_infs); } // generate infections using infectiousness -vector renewal_model(vector oR, vector uobs_infs +vector renewal_model(vector oR, vector uobs_inf, real gt_mean, real gt_sd, int max_gt, int pop, int ht) { // time indices and storage int ot = num_elements(oR); - int uot = num_elements(uobs_infs); + int uot = num_elements(uobs_inf); int nht = ot - ht; int t = ot + uot; vector[ot] R = oR; @@ -39,16 +39,16 @@ vector renewal_model(vector oR, vector uobs_infs vector[t] infections = rep_vector(1e-5, t); vector[ot] cum_infections = rep_vector(0, ot); vector[ot] infectiousness = rep_vector(1e-5, ot); + // generation time pmf + vector[max_gt] gt_pmf = rep_vector(1e-5, max_gt); + int gt_indexes[max_gt]; // Initialise infections - infections[1:uot] = uobs_infs; + infections[1:uot] = uobs_inf; // calculate cumulative infections if (pop) { cum_infections[1] = sum(infections[1:uot]); } - // generation time pmf - vector[max_gt] gt_pmf = rep_vector(1e-5, max_gt); // revert indices (this is for later doing the convolution with recent cases) - int gt_indexes[max_gt]; for (i in 1:(max_gt)) { gt_indexes[i] = max_gt - i + 1; } @@ -76,41 +76,38 @@ vector renewal_model(vector oR, vector uobs_infs return(infections); } // update infections using a growth model (linear,log, or non-parametric growth) -vector growth_model(vector r, int ht, vector uobs_infs, - int prior, vector constant) { +vector growth_model(vector r, vector uobs_inf, int ht) { // time indices and storage int ot = num_elements(r); - int uot = num_elements(seed_infections); + int uot = num_elements(uobs_inf); int nht = ot - ht; int t = ot + uot; vector[t] infections = rep_vector(1e-5, t); vector[ot] obs_inf; // Update observed infections - if (link == 0) { - if (prior == 1) { - obs_inf = constant .* r; - }else if (prior == 2) { - obs_inf[1] = uobs_inf[uot] * r[1]; - for (i in 2:t) { - obs_inf[i] = obs_inf[i - 1] * r[i]; - } - } - }else if (link == 1) { - if (prior == 1) { - obs_inf = constant + r; - }else if (prior == 2) { - obs_inf[1] = log(uobs_inf[uot]) + r[1]; - for (i in 2:t) { - obs_inf[i] = obs_inf[i - 1] + r[i]; - } - } - obs_inf = exp(obs_inf); + obs_inf[1] = uobs_inf[uot] * r[1]; + for (i in 2:t) { + obs_inf[i] = obs_inf[i - 1] * r[i]; } - infections[1:uot] = infections[1:uot] + uobs_inf; - infections[(uot + 1):t] = infections[(uot + 1):t] + obs_inf; + infections[1:uot] = infections[1:uot] + uobs_inf; + infections[(uot + 1):t] = infections[(uot + 1):t] + obs_inf; + return(infections); +} +// update infections using a growth model (linear,log, or non-parametric growth) +vector infection_model(vector cov, vector uobs_inf, int ht) { + // time indices and storage + int ot = num_elements(cov); + int uot = num_elements(uobs_inf); + int nht = ot - ht; + int t = ot + uot; + vector[t] infections = rep_vector(1e-5, t); + vector[ot] obs_inf = cov; + infections[1:uot] = infections[1:uot] + uobs_inf; + infections[(uot + 1):t] = infections[(uot + 1):t] + obs_inf; return(infections); } -// Update the log density for the generation time distribution mean and sd + +/// Update the log density for the generation time distribution mean and sd void generation_time_lp(real[] gt_mean, real gt_mean_mean, real gt_mean_sd, real[] gt_sd, real gt_sd_mean, real gt_sd_sd, int weight) { if (gt_mean_sd > 0) { @@ -121,3 +118,12 @@ void generation_time_lp(real[] gt_mean, real gt_mean_mean, real gt_mean_sd, } } +void infections_lp(real[] initial_infections, real[] initial_growth, + real prior_infections, real prior_growth, + int seeding_time) { + // initial infections + initial_infections ~ normal(prior_infections, 0.2); + if (seeding_time > 1) { + initial_growth ~ normal(prior_growth, 0.2); + } +} diff --git a/inst/stan/functions/observation_model.stan b/inst/stan/functions/observation_model.stan index 25014c21c..c57f44d7f 100644 --- a/inst/stan/functions/observation_model.stan +++ b/inst/stan/functions/observation_model.stan @@ -67,12 +67,12 @@ void truncation_lp(real[] truncation_mean, real[] truncation_sd, // update log density for reported cases void report_lp(int[] cases, vector reports, real[] rep_phi, real phi_mean, real phi_sd, - int model_type, real weight) { + int obs_dist, real weight) { real sqrt_phi = 1e5; - if (model_type) { + if (obs_dist) { // the reciprocal overdispersion parameter (phi) - rep_phi[model_type] ~ normal(phi_mean, phi_sd) T[0,]; - sqrt_phi = 1 / sqrt(rep_phi[model_type]); + rep_phi[obs_dist] ~ normal(phi_mean, phi_sd) T[0,]; + sqrt_phi = 1 / sqrt(rep_phi[obs_dist]); // defer to poisson if phi is large, to avoid overflow or // if poisson specified } @@ -93,13 +93,13 @@ void report_lp(int[] cases, vector reports, } // update log likelihood (as above but not vectorised and returning log likelihood) vector report_log_lik(int[] cases, vector reports, - real[] rep_phi, int model_type, real weight) { + real[] rep_phi, int obs_dist, real weight) { int t = num_elements(reports); vector[t] log_lik; real sqrt_phi = 1e5; - if (model_type) { + if (obs_dist) { // the reciprocal overdispersion parameter (phi) - sqrt_phi = 1 / sqrt(rep_phi[model_type]); + sqrt_phi = 1 / sqrt(rep_phi[obs_dist]); } // defer to poisson if phi is large, to avoid overflow @@ -115,12 +115,12 @@ vector report_log_lik(int[] cases, vector reports, return(log_lik); } // sample reported cases from the observation model -int[] report_rng(vector reports, real[] rep_phi, int model_type) { +int[] report_rng(vector reports, real[] rep_phi, int obs_dist) { int t = num_elements(reports); int sampled_reports[t]; real sqrt_phi = 1e5; - if (model_type) { - sqrt_phi = 1 / sqrt(rep_phi[model_type]); + if (obs_dist) { + sqrt_phi = 1 / sqrt(rep_phi[obs_dist]); } for (s in 1:t) { diff --git a/inst/stan/simulate_infections.stan b/inst/stan/simulate_infections.stan index 87875d85e..8275990a5 100644 --- a/inst/stan/simulate_infections.stan +++ b/inst/stan/simulate_infections.stan @@ -2,7 +2,6 @@ functions { #include functions/pmfs.stan #include functions/convolve.stan #include functions/gaussian_process.stan -#include functions/rt.stan #include functions/infections.stan #include functions/observation_model.stan #include functions/generated_quantities.stan @@ -28,11 +27,12 @@ generated quantities { matrix[n, t - seeding_time] reports; // observed cases int imputed_reports[n, t - seeding_time]; real r[n, t - seeding_time - 1]; + vector[seeding_time] uobs_inf; for (i in 1:n) { - // generate infections from Rt trace - infections[i] = to_row_vector(generate_infections(to_vector(R[i]), seeding_time, + uobs_inf = generate_seed(initial_infections[i], initial_growth[i], seeding_time); + // generate infections from Rt trace + infections[i] = to_row_vector(renewal_model(to_vector(R[i]), uobs_inf, gt_mean[i, 1], gt_sd[i, 1], max_gt, - initial_infections[i], initial_growth[i], pop, future_time)); // convolve from latent infections to mean of observations reports[i] = to_row_vector(convolve_to_report(to_vector(infections[i]), delay_mean[i], @@ -48,7 +48,7 @@ generated quantities { reports[i] = to_row_vector(scale_obs(to_vector(reports[i]), frac_obs[i, 1])); } // simulate reported cases - imputed_reports[i] = report_rng(to_vector(reports[i]), rep_phi[i], model_type); + imputed_reports[i] = report_rng(to_vector(reports[i]), rep_phi[i], obs_dist); r[i] = calculate_growth(to_vector(infections[i]), seeding_time + 1); } } diff --git a/inst/stan/simulate_secondary.stan b/inst/stan/simulate_secondary.stan index f7290fe49..5d9d1a99f 100644 --- a/inst/stan/simulate_secondary.stan +++ b/inst/stan/simulate_secondary.stan @@ -36,6 +36,6 @@ generated quantities { secondary = day_of_week_effect(secondary, day_of_week, to_vector(day_of_week_simplex[i])); } // simulate secondary reports - sim_secondary[i] = report_rng(tail(secondary, all_dates ? t : h), rep_phi[i], model_type); + sim_secondary[i] = report_rng(tail(secondary, all_dates ? t : h), rep_phi[i], obs_dist); } } diff --git a/tests/testthat/test-create_obs_model.R b/tests/testthat/test-create_obs_model.R index eb0135941..5e2a44d5f 100644 --- a/tests/testthat/test-create_obs_model.R +++ b/tests/testthat/test-create_obs_model.R @@ -5,12 +5,12 @@ test_that("create_obs_model works with default settings", { obs <- create_obs_model(dates = dates) expect_equal(length(obs), 11) expect_equal(names(obs), c( - "model_type", "phi_mean", "phi_sd", "week_effect", "obs_weight", + "obs_dist", "phi_mean", "phi_sd", "week_effect", "obs_weight", "obs_scale", "likelihood", "return_likelihood", "day_of_week", "obs_scale_mean", "obs_scale_sd" )) - expect_equal(obs$model_type, 1) + expect_equal(obs$obs_dist, 1) expect_equal(obs$week_effect, 7) expect_equal(obs$obs_scale, 0) expect_equal(obs$likelihood, 1) @@ -22,7 +22,7 @@ test_that("create_obs_model works with default settings", { test_that("create_obs_model can be used with a Poisson model", { obs <- create_obs_model(dates = dates, obs = obs_opts(family = "poisson")) - expect_equal(obs$model_type, 0) + expect_equal(obs$obs_dist, 0) }) test_that("create_obs_model can be used with a scaling", { @@ -51,4 +51,4 @@ test_that("create_obs_model can be used with a user set phi", { expect_equal(obs$phi_sd, 0.1) expect_error(obs_opts(phi = c(10))) expect_error(obs_opts(phi = c("Hi", "World"))) -}) \ No newline at end of file +})