From 1622195087a6d52aba1eb415bc0815cd0e81012c Mon Sep 17 00:00:00 2001 From: Sebastian Funk Date: Tue, 4 Jul 2023 14:07:57 +0100 Subject: [PATCH 1/2] reduce number of exponential calls --- inst/stan/functions/infections.stan | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/inst/stan/functions/infections.stan b/inst/stan/functions/infections.stan index 015d07168..7e5330081 100644 --- a/inst/stan/functions/infections.stan +++ b/inst/stan/functions/infections.stan @@ -33,8 +33,9 @@ vector generate_infections(vector oR, int uot, vector gt_rev_pmf, // Initialise infections using daily growth infections[1] = exp(initial_infections[1]); if (uot > 1) { + real growth = exp(initial_growth[1]); for (s in 2:uot) { - infections[s] = exp(initial_infections[1] + initial_growth[1] * (s - 1)); + infections[s] = infections[s - 1] * growth; } } // calculate cumulative infections From fefe19b4bada5c0abe228abcdcc816c9e4870953 Mon Sep 17 00:00:00 2001 From: Sebastian Funk Date: Tue, 4 Jul 2023 14:08:11 +0100 Subject: [PATCH 2/2] initialise vectors to zero --- inst/stan/functions/infections.stan | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/inst/stan/functions/infections.stan b/inst/stan/functions/infections.stan index 7e5330081..c02d9e22e 100644 --- a/inst/stan/functions/infections.stan +++ b/inst/stan/functions/infections.stan @@ -27,9 +27,9 @@ vector generate_infections(vector oR, int uot, vector gt_rev_pmf, int t = ot + uot; vector[ot] R = oR; real exp_adj_Rt; - vector[t] infections = rep_vector(1e-5, t); - vector[ot] cum_infections = rep_vector(0, ot); - vector[ot] infectiousness = rep_vector(1e-5, ot); + vector[t] infections = rep_vector(0, t); + vector[ot] cum_infections; + vector[ot] infectiousness; // Initialise infections using daily growth infections[1] = exp(initial_infections[1]); if (uot > 1) { @@ -44,13 +44,13 @@ vector generate_infections(vector oR, int uot, vector gt_rev_pmf, } // iteratively update infections for (s in 1:ot) { - infectiousness[s] += update_infectiousness(infections, gt_rev_pmf, uot, s); + infectiousness[s] = update_infectiousness(infections, gt_rev_pmf, uot, s); if (pop && s > nht) { exp_adj_Rt = exp(-R[s] * infectiousness[s] / (pop - cum_infections[nht])); exp_adj_Rt = exp_adj_Rt > 1 ? 1 : exp_adj_Rt; infections[s + uot] = (pop - cum_infections[s]) * (1 - exp_adj_Rt); }else{ - infections[s + uot] += R[s] * infectiousness[s]; + infections[s + uot] = R[s] * infectiousness[s]; } if (pop && s < ot) { cum_infections[s + 1] = cum_infections[s] + infections[s + uot];