Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue #408: Fit the susceptible population #904

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
Open
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
estimate_infections()
```

- Added support for fitting the susceptible population size. By @seabbs in #904 and reviewed by @sbfnk.
- A bug was fixed where the initial growth was never estimated (i.e. the prior mean was always zero). By @sbfnk in #853 and reviewed by @seabbs.
- A bug was fixed where an internal function for applying a default cdf cutoff failed due to a difference a vector length issue. By @jamesmbaazam in #858 and reviewed by @sbfnk.
- All parameters have been changed to the new parameter interface. By @sbfnk in #871 and #890 and reviewed by @seabbs.
Expand Down
7 changes: 5 additions & 2 deletions R/create.R
Original file line number Diff line number Diff line change
Expand Up @@ -263,7 +263,8 @@ create_rt_data <- function(rt = rt_opts(), breakpoints = NULL,
breakpoints = breakpoints,
future_fixed = as.numeric(future_rt$fixed),
fixed_from = future_rt$from,
pop = rt$pop,
use_pop =
as.integer(rt$pop != Fixed(0)) + as.integer(rt$pop_period == "all"),
stationary = as.numeric(rt$gp_on == "R0"),
future_time = horizon - future_rt$from
)
Expand Down Expand Up @@ -584,12 +585,14 @@ create_stan_data <- function(data, seeding_time, rt, gp, obs, backcalc,
R0 = rt$prior,
frac_obs = obs$scale,
rep_phi = obs$phi,
pop = rt$pop,
lower_bounds = c(
alpha = 0,
rho = 0,
R0 = 0,
frac_obs = 0,
rep_phi = 0
rep_phi = 0,
pop = 0
)
)
)
Expand Down
41 changes: 34 additions & 7 deletions R/opts.R
Original file line number Diff line number Diff line change
Expand Up @@ -320,11 +320,17 @@ trunc_opts <- function(dist = Fixed(0), default_cdf_cutoff = 0.001,
#' conservative estimate of break point changes (alter this by setting
#' `gp = NULL`).
#'
#' @param pop Integer, defaults to 0. Susceptible population initially present.
#' Used to adjust Rt estimates when otherwise fixed based on the proportion of
#' the population that is susceptible. When set to 0 no population adjustment
#' @param pop A `<dist_spec>` giving the initial susceptible population size.
#' Used to adjust Rt estimates based on the proportion of the population that
#' is susceptible. Defaults to `Fixed(0)` which means no population adjustment
#' is done.
#'
#' @param pop_period Character string, defaulting to "forecast". Controls when
#' susceptible population adjustment is applied. "forecast" only applies the
#' adjustment to forecasts while "all" applies it to both data and forecasts.
#' Note that with "forecast", Rt estimates are unadjusted for susceptible
#' depletion but posterior predictions are adjusted.
Comment on lines +331 to +332
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
#' Note that with "forecast", Rt estimates are unadjusted for susceptible
#' depletion but posterior predictions are adjusted.
#' Note that with "all", Rt estimates are unadjusted for susceptible
#' depletion but posterior predictions of infections and reports are
#' adjusted.

#'
#' @param gp_on Character string, defaulting to "R_t-1". Indicates how the
#' Gaussian process, if in use, should be applied to Rt. Currently supported
#' options are applying the Gaussian process to the last estimated Rt (i.e
Expand Down Expand Up @@ -354,14 +360,15 @@ rt_opts <- function(prior = LogNormal(mean = 1, sd = 1),
use_breakpoints = TRUE,
future = "latest",
gp_on = c("R_t-1", "R0"),
pop = 0) {
pop = Fixed(0),
pop_period = c("forecast", "all")) {
rt <- list(
use_rt = use_rt,
rw = rw,
use_breakpoints = use_breakpoints,
future = future,
pop = pop,
gp_on = arg_match(gp_on)
gp_on = arg_match(gp_on),
pop_period = arg_match(pop_period)
)

# replace default settings with those specified by user
Expand All @@ -388,6 +395,23 @@ rt_opts <- function(prior = LogNormal(mean = 1, sd = 1),
prior <- LogNormal(mean = prior$mean, sd = prior$sd)
}

if (is.numeric(pop)) {
lifecycle::deprecate_warn(
"1.7.0",
"rt_opts(pop = 'must be a `<dist_spec>`')",
details = "For specifying a fixed population size, use `Fixed(pop)`"
)
pop <- Fixed(pop)
}
rt$pop <- pop
if (rt$pop_period == "all" && pop == Fixed(0)) {
cli_abort(
c(
"!" = "pop_period = \"all\" but pop is fixed at 0."
)
)
}

if (rt$use_rt) {
rt$prior <- prior
} else {
Expand Down Expand Up @@ -724,7 +748,10 @@ obs_opts <- function(family = c("negbin", "poisson"),
cli_abort(
c(
"!" = "Specifying {.var phi} as a vector of length 2 is deprecated.",
"i" = "Mean and SD should be given as list elements."
"i" = paste0(
"Use a {.cls dist_spec} instead, e.g. Normal(mean = {phi[1]}, ",
"sd = {phi[2]})."
)
)
)
}
Expand Down
9 changes: 5 additions & 4 deletions R/simulate_infections.R
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ simulate_infections <- function(estimates, R, initial_infections,
CrIs = c(0.2, 0.5, 0.9),
backend = "rstan",
seeding_time = NULL,
pop = 0, ...) {
pop = Fixed(0), ...) {
## deprecated usage
if (!missing(estimates)) {
deprecate_stop(
Expand All @@ -86,14 +86,14 @@ simulate_infections <- function(estimates, R, initial_infections,
assert_numeric(R$R, lower = 0)
assert_numeric(initial_infections, lower = 0)
assert_numeric(day_of_week_effect, lower = 0, null.ok = TRUE)
assert_numeric(pop, lower = 0)
if (!is.null(seeding_time)) {
assert_integerish(seeding_time, lower = 1)
}
assert_class(delays, "delay_opts")
assert_class(truncation, "trunc_opts")
assert_class(obs, "obs_opts")
assert_class(generation_time, "generation_time_opts")
assert_class(pop, "dist_spec")

## create R for all dates modelled
all_dates <- data.table(date = seq.Date(min(R$date), max(R$date), by = "day"))
Expand Down Expand Up @@ -125,7 +125,7 @@ simulate_infections <- function(estimates, R, initial_infections,
initial_infections = array(log_initial_infections, dim = c(1, 1)),
initial_growth = array(initial_growth, dim = c(1, length(initial_growth))),
R = array(R$R, dim = c(1, nrow(R))),
pop = pop
use_pop = as.integer(pop != Fixed(0))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

can be a future issue but I think we want to be able to have susceptible depletion across full simulations, too.

)

data <- c(data, create_stan_delays(
Expand Down Expand Up @@ -179,7 +179,8 @@ simulate_infections <- function(estimates, R, initial_infections,
rho = NULL,
R0 = NULL,
frac_obs = obs$scale,
rep_phi = obs$phi
rep_phi = obs$phi,
pop = pop
))
## set empty params matrix - variable parameters not supported here
data$params <- array(dim = c(1, 0))
Expand Down
1 change: 1 addition & 0 deletions inst/stan/data/estimate_infections_params.stan
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ int<lower = 0> rho_id; // parameter id of rho (GP lengthscale)
int<lower = 0> R0_id; // parameter id of R0
int<lower = 0> frac_obs_id; // parameter id of frac_obs
int<lower = 0> rep_phi_id; // parameter id of rep_phi_id
int<lower = 0> pop_id; // parameter id of pop
2 changes: 1 addition & 1 deletion inst/stan/data/rt.stan
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,5 @@
array[t - seeding_time] int breakpoints; // when do breakpoints occur
int future_fixed; // is underlying future Rt assumed to be fixed
int fixed_from; // Reference date for when Rt estimation should be fixed
int pop; // Initial susceptible population
int use_pop; // use population size
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
int use_pop; // use population size
int use_pop; // use population size (0 = no; 1 = forecasts; 2 = all)

int<lower = 0> gt_id; // id of generation time
2 changes: 1 addition & 1 deletion inst/stan/data/simulation_rt.stan
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,6 @@
array[n, seeding_time > 1 ? 1 : 0] real initial_growth; //initial growth

matrix[n, t - seeding_time] R; // reproduction number
int pop; // susceptible population
int use_pop; // use population size
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
int use_pop; // use population size
int use_pop; // use population size (0 = no; 1 = forecasts)


int<lower = 0> gt_id; // id of generation time
6 changes: 5 additions & 1 deletion inst/stan/estimate_infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -107,9 +107,13 @@ transformed parameters {
frac_obs_id, params_fixed_lookup, params_variable_lookup, params_value,
params
);
real pop = get_param(
pop_id, params_fixed_lookup, params_variable_lookup, params_value,
params
);
infections = generate_infections(
R, seeding_time, gt_rev_pmf, initial_infections, initial_growth, pop,
future_time, obs_scale, frac_obs
use_pop, future_time, obs_scale, frac_obs
);
}
} else {
Expand Down
8 changes: 4 additions & 4 deletions inst/stan/functions/infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ real update_infectiousness(vector infections, vector gt_rev_pmf,
// generate infections by using Rt = Rt-1 * sum(reversed generation time pmf * infections)
vector generate_infections(vector oR, int uot, vector gt_rev_pmf,
array[] real initial_infections, array[] real initial_growth,
int pop, int ht, int obs_scale, real frac_obs) {
real pop, int use_pop, int ht, int obs_scale, real frac_obs) {
// time indices and storage
int ot = num_elements(oR);
int nht = ot - ht;
Expand All @@ -42,20 +42,20 @@ vector generate_infections(vector oR, int uot, vector gt_rev_pmf,
}
}
// calculate cumulative infections
if (pop) {
if (use_pop) {
cum_infections[1] = sum(infections[1:uot]);
}
// iteratively update infections
for (s in 1:ot) {
infectiousness[s] = update_infectiousness(infections, gt_rev_pmf, uot, s);
if (pop && s > nht) {
if (use_pop == 1 || (use_pop == 2 && s <= nht)) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry I'm not still getting the logic. We have:
use_pop == 1: only adjust in forecast
use_pop == 2: adjust across all

So shouldn't this be

Suggested change
if (use_pop == 1 || (use_pop == 2 && s <= nht)) {
if ((use_pop == 1 && s > nht) || use_pop == 2)) {

exp_adj_Rt = exp(-R[s] * infectiousness[s] / (pop - cum_infections[nht]));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
exp_adj_Rt = exp(-R[s] * infectiousness[s] / (pop - cum_infections[nht]));
exp_adj_Rt = exp(-R[s] * infectiousness[s] / (pop - cum_infections[s]));

exp_adj_Rt = exp_adj_Rt > 1 ? 1 : exp_adj_Rt;
infections[s + uot] = (pop - cum_infections[s]) * (1 - exp_adj_Rt);
}else{
infections[s + uot] = R[s] * infectiousness[s];
}
if (pop && s < ot) {
if (use_pop && s < ot) {
cum_infections[s + 1] = cum_infections[s] + infections[s + uot];
}
}
Expand Down
8 changes: 7 additions & 1 deletion inst/stan/simulate_infections.stan
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@ generated quantities {
frac_obs_id, params_fixed_lookup, params_variable_lookup,
params_value, params
);

vector[n] pop = get_param(
pop_id, params_fixed_lookup, params_variable_lookup,
params_value, params
);

for (i in 1:n) {
// generate infections from Rt trace
vector[delay_type_max[gt_id] + 1] gt_rev_pmf;
Expand All @@ -62,7 +68,7 @@ generated quantities {

infections[i] = to_row_vector(generate_infections(
to_vector(R[i]), seeding_time, gt_rev_pmf, initial_infections[i],
initial_growth[i], pop, future_time, obs_scale, frac_obs[i]
initial_growth[i], pop[i], use_pop, future_time, obs_scale, frac_obs[i]
));

if (delay_id) {
Expand Down
2 changes: 1 addition & 1 deletion man/create_forecast_data.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/create_stan_data.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/epinow.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/estimate_infections.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 0 additions & 1 deletion man/forecast_opts.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/regional_epinow.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 11 additions & 4 deletions man/rt_opts.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 4 additions & 4 deletions man/simulate_infections.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions tests/testthat/test-create_rt_date.R
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ test_that("create_rt_data returns expected default values", {
expect_equal(result$breakpoints, numeric(0))
expect_equal(result$future_fixed, 1)
expect_equal(result$fixed_from, 0)
expect_equal(result$pop, 0)
expect_equal(result$use_pop, 0)
expect_equal(result$stationary, 0)
expect_equal(result$future_time, 0)
})
Expand All @@ -27,13 +27,13 @@ test_that("create_rt_data handles custom rt_opts correctly", {
use_breakpoints = FALSE,
future = "project",
gp_on = "R0",
pop = 1000000
pop = Normal(mean = 1000000, sd = 100)
)

result <- create_rt_data(rt = custom_rt, horizon = 7)

expect_equal(result$estimate_r, 0)
expect_equal(result$pop, 1000000)
expect_equal(result$use_pop, 1)
expect_equal(result$stationary, 1)
expect_equal(result$future_time, 7)
})
Expand Down
Loading
Loading