diff --git a/inst/stan/functions/infections.stan b/inst/stan/functions/infections.stan index 015d07168..c02d9e22e 100644 --- a/inst/stan/functions/infections.stan +++ b/inst/stan/functions/infections.stan @@ -27,14 +27,15 @@ vector generate_infections(vector oR, int uot, vector gt_rev_pmf, int t = ot + uot; vector[ot] R = oR; real exp_adj_Rt; - vector[t] infections = rep_vector(1e-5, t); - vector[ot] cum_infections = rep_vector(0, ot); - vector[ot] infectiousness = rep_vector(1e-5, ot); + vector[t] infections = rep_vector(0, t); + vector[ot] cum_infections; + vector[ot] infectiousness; // Initialise infections using daily growth infections[1] = exp(initial_infections[1]); if (uot > 1) { + real growth = exp(initial_growth[1]); for (s in 2:uot) { - infections[s] = exp(initial_infections[1] + initial_growth[1] * (s - 1)); + infections[s] = infections[s - 1] * growth; } } // calculate cumulative infections @@ -43,13 +44,13 @@ vector generate_infections(vector oR, int uot, vector gt_rev_pmf, } // iteratively update infections for (s in 1:ot) { - infectiousness[s] += update_infectiousness(infections, gt_rev_pmf, uot, s); + infectiousness[s] = update_infectiousness(infections, gt_rev_pmf, uot, s); if (pop && s > nht) { exp_adj_Rt = exp(-R[s] * infectiousness[s] / (pop - cum_infections[nht])); exp_adj_Rt = exp_adj_Rt > 1 ? 1 : exp_adj_Rt; infections[s + uot] = (pop - cum_infections[s]) * (1 - exp_adj_Rt); }else{ - infections[s + uot] += R[s] * infectiousness[s]; + infections[s + uot] = R[s] * infectiousness[s]; } if (pop && s < ot) { cum_infections[s + 1] = cum_infections[s] + infections[s + uot];