From 0d0be9f729106e1fca3459fefa8c7ee425c6de28 Mon Sep 17 00:00:00 2001 From: athowes Date: Thu, 7 Nov 2024 12:41:44 +0000 Subject: [PATCH 01/62] Add cohort model template --- R/cohort_model.R | 50 ++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 50 insertions(+) create mode 100644 R/cohort_model.R diff --git a/R/cohort_model.R b/R/cohort_model.R new file mode 100644 index 000000000..8510856de --- /dev/null +++ b/R/cohort_model.R @@ -0,0 +1,50 @@ +#' Prepare cohort model +#' +#' @param data A `data.frame` containing line list data +#' @family cohort_model +#' @export +as_direct_model <- function(data) { + UseMethod("as_cohort_model") +} + +assert_cohort_model_input <- function(data) { + # ... +} + +#' Prepare cohort model +#' +#' @param data A `data.frame` containing line list data +#' @rdname as_direct_model +#' @method as_direct_model data.frame +#' @family cohort_model +#' @autoglobal +#' @export +as_direct_model.data.frame <- function(data) { + assert_direct_model_input(data) + class(data) <- c("epidist_direct_model", class(data)) + data <- data |> + mutate(delay = .data$stime - .data$ptime) + epidist_validate(data) + return(data) +} + +#' Validate cohort model data +#' +#' @param data A `data.frame` containing line list data +#' @param ... ... +#' @method epidist_validate epidist_cohort_model +#' @family cohort_model +#' @export +epidist_validate.epidist_cohort_model <- function(data, ...) { + assert_true(is_cohort_model(data)) + assert_cohort_model_input(data) +} + +#' Check if data has the `epidist_cohort_model` class +#' +#' @param data A `data.frame` containing line list data +#' @family cohort_model +#' @export +is_direct_model <- function(data) { + inherits(data, "epidist_direct_model") +} From 810423ac22dc0d85f1029c486fb944c49319c4b1 Mon Sep 17 00:00:00 2001 From: athowes Date: Mon, 11 Nov 2024 09:44:42 +0000 Subject: [PATCH 02/62] Fix to previous commit (name as cohort_model) --- R/cohort_model.R | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/R/cohort_model.R b/R/cohort_model.R index 8510856de..d1d25d91e 100644 --- a/R/cohort_model.R +++ b/R/cohort_model.R @@ -3,7 +3,7 @@ #' @param data A `data.frame` containing line list data #' @family cohort_model #' @export -as_direct_model <- function(data) { +as_cohort_model <- function(data) { UseMethod("as_cohort_model") } @@ -14,14 +14,14 @@ assert_cohort_model_input <- function(data) { #' Prepare cohort model #' #' @param data A `data.frame` containing line list data -#' @rdname as_direct_model -#' @method as_direct_model data.frame +#' @rdname as_cohort_model +#' @method as_cohort_model data.frame #' @family cohort_model #' @autoglobal #' @export -as_direct_model.data.frame <- function(data) { - assert_direct_model_input(data) - class(data) <- c("epidist_direct_model", class(data)) +as_cohort_model.data.frame <- function(data) { + assert_cohort_model_input(data) + class(data) <- c("epidist_cohort_model", class(data)) data <- data |> mutate(delay = .data$stime - .data$ptime) epidist_validate(data) @@ -45,6 +45,6 @@ epidist_validate.epidist_cohort_model <- function(data, ...) { #' @param data A `data.frame` containing line list data #' @family cohort_model #' @export -is_direct_model <- function(data) { - inherits(data, "epidist_direct_model") +is_cohort_model <- function(data) { + inherits(data, "epidist_cohort_model") } From 8f4c34a05e2d048b54d98864b9c236486e0ecc01 Mon Sep 17 00:00:00 2001 From: athowes Date: Mon, 11 Nov 2024 10:10:44 +0000 Subject: [PATCH 03/62] Generate simulated cohort data --- inst/cohort-scratch.R | 31 +++++++++++++++++++++++++++++++ 1 file changed, 31 insertions(+) create mode 100644 inst/cohort-scratch.R diff --git a/inst/cohort-scratch.R b/inst/cohort-scratch.R new file mode 100644 index 000000000..04304acbf --- /dev/null +++ b/inst/cohort-scratch.R @@ -0,0 +1,31 @@ +library(dplyr) +library(ggplot2) +library(brms) +library(primarycensored) + +set.seed(101) + +obs_time <- 25 +sample_size <- 500 + +meanlog <- 1.8 +sdlog <- 0.5 + +sim_obs <- simulate_gillespie() |> + simulate_secondary( + dist = rlnorm, + meanlog = meanlog, + sdlog = sdlog + ) |> + observe_process() |> + filter_obs_by_obs_time(obs_time = obs_time) |> + dplyr::slice_sample(n = sample_size, replace = FALSE) + +# Create cohort version of data + +cohort_obs <- sim_obs |> + group_by(delay = delay_daily) |> + summarise(n = n()) + +ggplot(cohort_obs, aes(x = delay, y = n)) + + geom_col() From 8bcfbe735b955769dc345bc6fe3ea189de7546bb Mon Sep 17 00:00:00 2001 From: athowes Date: Mon, 11 Nov 2024 10:21:16 +0000 Subject: [PATCH 04/62] Add unweighted and weighted direct models --- inst/cohort-scratch.R | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/inst/cohort-scratch.R b/inst/cohort-scratch.R index 04304acbf..d11a009b4 100644 --- a/inst/cohort-scratch.R +++ b/inst/cohort-scratch.R @@ -2,6 +2,7 @@ library(dplyr) library(ggplot2) library(brms) library(primarycensored) +library(bayesplot) set.seed(101) @@ -29,3 +30,19 @@ cohort_obs <- sim_obs |> ggplot(cohort_obs, aes(x = delay, y = n)) + geom_col() + +fit_direct <- brms::brm( + formula = delay_daily ~ 1, + family = "lognormal", + data = sim_obs +) + +summary(fit_direct) + +fit_direct_weighted <- brms::brm( + formula = delay | weights(n) ~ 1, + family = "lognormal", + cohort_obs, +) + +summary(fit_direct_weighted) From e95316ab1c510ab8dc67fe477e76a843b3772362 Mon Sep 17 00:00:00 2001 From: athowes Date: Mon, 11 Nov 2024 10:56:41 +0000 Subject: [PATCH 05/62] Thinking about custom family for pcd function --- inst/cohort-scratch.R | 21 +++++++++++++++++++++ 1 file changed, 21 insertions(+) diff --git a/inst/cohort-scratch.R b/inst/cohort-scratch.R index d11a009b4..bd55bccf9 100644 --- a/inst/cohort-scratch.R +++ b/inst/cohort-scratch.R @@ -46,3 +46,24 @@ fit_direct_weighted <- brms::brm( ) summary(fit_direct_weighted) + +lognormal <- brms::lognormal() + +primarycensored_lognormal_uniform_lcdf <- brms::custom_family( + "primarycensored_lognormal_uniform_lcdf", + dpars = lognormal$dpar, + links = c(lognormal$link, lognormal$link_sigma), + type = lognormal$type, + loop = FALSE +) + +primarycensored_lognormal_uniform_lcdf_file <- file.path( + tempdir(), "primarycensored_lognormal_uniform_lcdf.stan" +) + +pcd_load_stan_functions( + "primarycensored_lognormal_uniform_lcdf", + write_to_file = TRUE, + output_file = primarycensored_lognormal_uniform_lcdf_file, + wrap_in_block = TRUE +) From 2165646c8fa355e07b13e6b31779990e5fb5d19a Mon Sep 17 00:00:00 2001 From: athowes Date: Tue, 12 Nov 2024 10:06:55 +0000 Subject: [PATCH 06/62] functions component of stanvars --- inst/cohort-scratch.R | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/inst/cohort-scratch.R b/inst/cohort-scratch.R index bd55bccf9..1b80f9e4c 100644 --- a/inst/cohort-scratch.R +++ b/inst/cohort-scratch.R @@ -61,9 +61,7 @@ primarycensored_lognormal_uniform_lcdf_file <- file.path( tempdir(), "primarycensored_lognormal_uniform_lcdf.stan" ) -pcd_load_stan_functions( - "primarycensored_lognormal_uniform_lcdf", - write_to_file = TRUE, - output_file = primarycensored_lognormal_uniform_lcdf_file, - wrap_in_block = TRUE +stanvars_functions <- brms::stanvar( + block = "functions", + scode = pcd_load_stan_functions("primarycensored_lognormal_uniform_lcdf") ) From 602f01589704b9e303eb03d9275e7f59863b8687 Mon Sep 17 00:00:00 2001 From: athowes Date: Tue, 12 Nov 2024 10:40:17 +0000 Subject: [PATCH 07/62] Add transformed parameters for cohort model --- inst/cohort-scratch.R | 9 +++++++++ inst/stan/cohort_model/tparameters.stan | 3 +++ 2 files changed, 12 insertions(+) create mode 100644 inst/stan/cohort_model/tparameters.stan diff --git a/inst/cohort-scratch.R b/inst/cohort-scratch.R index 1b80f9e4c..3dc530c2d 100644 --- a/inst/cohort-scratch.R +++ b/inst/cohort-scratch.R @@ -65,3 +65,12 @@ stanvars_functions <- brms::stanvar( block = "functions", scode = pcd_load_stan_functions("primarycensored_lognormal_uniform_lcdf") ) + +stanvars_tparameters <- brms::stanvar( + block = "tparameters", + scode = .stan_chunk("cohort_model/tparameters.stan") +) + +stanvars_all <- stanvars_functions + stanvars_tparameters + +stanvars_all diff --git a/inst/stan/cohort_model/tparameters.stan b/inst/stan/cohort_model/tparameters.stan new file mode 100644 index 000000000..f0f102c40 --- /dev/null +++ b/inst/stan/cohort_model/tparameters.stan @@ -0,0 +1,3 @@ +vector[2] params; +params[1] = mu; +params[2] = sigma; From 8d4c7379029a2678e96f8fa15f08bf6806eea70a Mon Sep 17 00:00:00 2001 From: athowes Date: Tue, 12 Nov 2024 11:24:07 +0000 Subject: [PATCH 08/62] Progress on implementing PCD model --- inst/cohort-scratch.R | 37 ++++++++++++++++++++++++++++--- inst/stan/cohort_model/data.stan | 1 + inst/stan/cohort_model/tdata.stan | 1 + 3 files changed, 36 insertions(+), 3 deletions(-) create mode 100644 inst/stan/cohort_model/data.stan create mode 100644 inst/stan/cohort_model/tdata.stan diff --git a/inst/cohort-scratch.R b/inst/cohort-scratch.R index 3dc530c2d..0d2c9aa3a 100644 --- a/inst/cohort-scratch.R +++ b/inst/cohort-scratch.R @@ -54,13 +54,18 @@ primarycensored_lognormal_uniform_lcdf <- brms::custom_family( dpars = lognormal$dpar, links = c(lognormal$link, lognormal$link_sigma), type = lognormal$type, - loop = FALSE + loop = TRUE, + vars = "pwindow" ) primarycensored_lognormal_uniform_lcdf_file <- file.path( tempdir(), "primarycensored_lognormal_uniform_lcdf.stan" ) +data <- cohort_obs |> + select(d = delay, n = n) |> + mutate(pwindow = 1) + stanvars_functions <- brms::stanvar( block = "functions", scode = pcd_load_stan_functions("primarycensored_lognormal_uniform_lcdf") @@ -71,6 +76,32 @@ stanvars_tparameters <- brms::stanvar( scode = .stan_chunk("cohort_model/tparameters.stan") ) -stanvars_all <- stanvars_functions + stanvars_tparameters +stanvars_tdata <- brms::stanvar( + block = "tdata", + scode = .stan_chunk("cohort_model/tdata.stan") +) + +pwindow <- data$pwindow + +stanvars_data <- brms::stanvar( + x = pwindow, + block = "data", + scode = .stan_chunk("cohort_model/data.stan") +) -stanvars_all +stanvars_all <- stanvars_functions + stanvars_tparameters + stanvars_tdata + + stanvars_data + +brms::make_stancode( + formula = d | weights(n) ~ 1, + family = primarycensored_lognormal_uniform_lcdf, + data = data, + stanvars = stanvars_all, +) + +fit_pcd <- brms::brm( + formula = d | weights(n) ~ 1, + family = primarycensored_lognormal_uniform_lcdf, + data = data, + stanvars = stanvars_all, +) diff --git a/inst/stan/cohort_model/data.stan b/inst/stan/cohort_model/data.stan new file mode 100644 index 000000000..00a3e0b49 --- /dev/null +++ b/inst/stan/cohort_model/data.stan @@ -0,0 +1 @@ +vector[N] pwindow; diff --git a/inst/stan/cohort_model/tdata.stan b/inst/stan/cohort_model/tdata.stan new file mode 100644 index 000000000..832112772 --- /dev/null +++ b/inst/stan/cohort_model/tdata.stan @@ -0,0 +1 @@ +real q = fmax(Y - pwindow, 0); From 5590a0f5991e8bf4ea309ed05f73552b6e0dd66c Mon Sep 17 00:00:00 2001 From: athowes Date: Tue, 12 Nov 2024 11:36:41 +0000 Subject: [PATCH 09/62] Set q as vector --- inst/stan/cohort_model/tdata.stan | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/inst/stan/cohort_model/tdata.stan b/inst/stan/cohort_model/tdata.stan index 832112772..0e3a441b1 100644 --- a/inst/stan/cohort_model/tdata.stan +++ b/inst/stan/cohort_model/tdata.stan @@ -1 +1 @@ -real q = fmax(Y - pwindow, 0); +vector[N] q = fmax(Y - pwindow, 0); From d005cbaa6811b6d1d8c80478a8bd5a32cfeb1654 Mon Sep 17 00:00:00 2001 From: athowes Date: Wed, 13 Nov 2024 10:06:17 +0000 Subject: [PATCH 10/62] Get rid of params input --- inst/cohort-scratch.R | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/inst/cohort-scratch.R b/inst/cohort-scratch.R index 0d2c9aa3a..ca8e1502b 100644 --- a/inst/cohort-scratch.R +++ b/inst/cohort-scratch.R @@ -66,15 +66,19 @@ data <- cohort_obs |> select(d = delay, n = n) |> mutate(pwindow = 1) +pcd_function <- pcd_load_stan_functions("primarycensored_lognormal_uniform_lcdf") +pcd_function <- sub(pattern = "array\\[\\] real params", "real mu, real sigma", pcd_function) +pcd_function <- gsub("\\s*real mu = params\\[1\\];\\n\\s*real sigma = params\\[2\\];\\n", "", pcd_function) + stanvars_functions <- brms::stanvar( block = "functions", - scode = pcd_load_stan_functions("primarycensored_lognormal_uniform_lcdf") + scode = pcd_function ) -stanvars_tparameters <- brms::stanvar( - block = "tparameters", - scode = .stan_chunk("cohort_model/tparameters.stan") -) +# stanvars_tparameters <- brms::stanvar( +# block = "tparameters", +# scode = .stan_chunk("cohort_model/tparameters.stan") +# ) stanvars_tdata <- brms::stanvar( block = "tdata", @@ -89,8 +93,7 @@ stanvars_data <- brms::stanvar( scode = .stan_chunk("cohort_model/data.stan") ) -stanvars_all <- stanvars_functions + stanvars_tparameters + stanvars_tdata + - stanvars_data +stanvars_all <- stanvars_functions + stanvars_tdata + stanvars_data brms::make_stancode( formula = d | weights(n) ~ 1, @@ -105,3 +108,5 @@ fit_pcd <- brms::brm( data = data, stanvars = stanvars_all, ) + + From 31c37368f38df3ccb85d657b08874b5faca6cef7 Mon Sep 17 00:00:00 2001 From: athowes Date: Wed, 13 Nov 2024 11:08:07 +0000 Subject: [PATCH 11/62] This would work, apart from it's the CDF. For the PMF need to import many primarycensored functions.. --- inst/cohort-scratch.R | 25 +++++++++++++------------ 1 file changed, 13 insertions(+), 12 deletions(-) diff --git a/inst/cohort-scratch.R b/inst/cohort-scratch.R index ca8e1502b..f392dcf51 100644 --- a/inst/cohort-scratch.R +++ b/inst/cohort-scratch.R @@ -55,7 +55,7 @@ primarycensored_lognormal_uniform_lcdf <- brms::custom_family( links = c(lognormal$link, lognormal$link_sigma), type = lognormal$type, loop = TRUE, - vars = "pwindow" + vars = c("pwindow", "vreal1[n]") ) primarycensored_lognormal_uniform_lcdf_file <- file.path( @@ -64,8 +64,11 @@ primarycensored_lognormal_uniform_lcdf_file <- file.path( data <- cohort_obs |> select(d = delay, n = n) |> - mutate(pwindow = 1) - + mutate( + pwindow = 1, + q = pmax(d - pwindow, 0) + ) + pcd_function <- pcd_load_stan_functions("primarycensored_lognormal_uniform_lcdf") pcd_function <- sub(pattern = "array\\[\\] real params", "real mu, real sigma", pcd_function) pcd_function <- gsub("\\s*real mu = params\\[1\\];\\n\\s*real sigma = params\\[2\\];\\n", "", pcd_function) @@ -80,10 +83,10 @@ stanvars_functions <- brms::stanvar( # scode = .stan_chunk("cohort_model/tparameters.stan") # ) -stanvars_tdata <- brms::stanvar( - block = "tdata", - scode = .stan_chunk("cohort_model/tdata.stan") -) +# stanvars_tdata <- brms::stanvar( +# block = "tdata", +# scode = .stan_chunk("cohort_model/tdata.stan") +# ) pwindow <- data$pwindow @@ -93,20 +96,18 @@ stanvars_data <- brms::stanvar( scode = .stan_chunk("cohort_model/data.stan") ) -stanvars_all <- stanvars_functions + stanvars_tdata + stanvars_data +stanvars_all <- stanvars_functions + stanvars_data brms::make_stancode( - formula = d | weights(n) ~ 1, + formula = d | weights(n) + vreal(q) ~ 1, family = primarycensored_lognormal_uniform_lcdf, data = data, stanvars = stanvars_all, ) fit_pcd <- brms::brm( - formula = d | weights(n) ~ 1, + formula = d | weights(n), vreal(q) ~ 1, family = primarycensored_lognormal_uniform_lcdf, data = data, stanvars = stanvars_all, ) - - From b4b71a0bd46e8d1fe913e252673ceb476efaffa2 Mon Sep 17 00:00:00 2001 From: athowes Date: Wed, 13 Nov 2024 11:40:23 +0000 Subject: [PATCH 12/62] Almost working with "import all functions" strategy --- inst/cohort-scratch.R | 29 +- .../cohort_model/primarycensored-edit.stan | 844 ++++++++++++++++++ 2 files changed, 856 insertions(+), 17 deletions(-) create mode 100644 inst/stan/cohort_model/primarycensored-edit.stan diff --git a/inst/cohort-scratch.R b/inst/cohort-scratch.R index f392dcf51..61d3b5b27 100644 --- a/inst/cohort-scratch.R +++ b/inst/cohort-scratch.R @@ -49,17 +49,13 @@ summary(fit_direct_weighted) lognormal <- brms::lognormal() -primarycensored_lognormal_uniform_lcdf <- brms::custom_family( - "primarycensored_lognormal_uniform_lcdf", +primarycensored_lognormal_uniform_lpmf <- brms::custom_family( + "primarycensored_lognormal_uniform", dpars = lognormal$dpar, links = c(lognormal$link, lognormal$link_sigma), - type = lognormal$type, + type = "int", loop = TRUE, - vars = c("pwindow", "vreal1[n]") -) - -primarycensored_lognormal_uniform_lcdf_file <- file.path( - tempdir(), "primarycensored_lognormal_uniform_lcdf.stan" + vars = c("vreal1[n]", "pwindow[n]") ) data <- cohort_obs |> @@ -68,14 +64,10 @@ data <- cohort_obs |> pwindow = 1, q = pmax(d - pwindow, 0) ) - -pcd_function <- pcd_load_stan_functions("primarycensored_lognormal_uniform_lcdf") -pcd_function <- sub(pattern = "array\\[\\] real params", "real mu, real sigma", pcd_function) -pcd_function <- gsub("\\s*real mu = params\\[1\\];\\n\\s*real sigma = params\\[2\\];\\n", "", pcd_function) stanvars_functions <- brms::stanvar( block = "functions", - scode = pcd_function + scode = .stan_chunk("cohort_model/primarycensored-edit.stan") ) # stanvars_tparameters <- brms::stanvar( @@ -98,16 +90,19 @@ stanvars_data <- brms::stanvar( stanvars_all <- stanvars_functions + stanvars_data -brms::make_stancode( +stancode <- brms::make_stancode( formula = d | weights(n) + vreal(q) ~ 1, - family = primarycensored_lognormal_uniform_lcdf, + family = primarycensored_lognormal_uniform_lpmf, data = data, stanvars = stanvars_all, ) +model <- rstan::stan_model(model_code = stancode) + fit_pcd <- brms::brm( - formula = d | weights(n), vreal(q) ~ 1, - family = primarycensored_lognormal_uniform_lcdf, + formula = d | weights(n) + vreal(q) ~ 1, + family = primarycensored_lognormal_uniform_lpmf, data = data, stanvars = stanvars_all, + backend = "cmdstanr" ) diff --git a/inst/stan/cohort_model/primarycensored-edit.stan b/inst/stan/cohort_model/primarycensored-edit.stan new file mode 100644 index 000000000..0b2899b70 --- /dev/null +++ b/inst/stan/cohort_model/primarycensored-edit.stan @@ -0,0 +1,844 @@ +// Copied from https://github.com/epinowcast/primarycensored/blob/main/inst/stan/functions/ +// Edited to work with temporary brms function + +real primarycensored_lognormal_uniform_lpmf(data int d, real mu, real sigma, real q, data real pwindow) { + array[2] real params = {mu, sigma}; + array[0] real primary_params; + int d_upper = d + 1; + return primarycensored_lpmf(d | 1, params, pwindow, d_upper, positive_infinity(), 1, primary_params); +} + +/** + * Primary event censored distribution functions + */ + +/** + * Compute the primary event censored CDF for a single delay + * + * @param d Delay + * @param dist_id Distribution identifier + * @param params Array of distribution parameters + * @param pwindow Primary event window + * @param D Maximum delay (truncation point) + * @param primary_id Primary distribution identifier + * @param primary_params Primary distribution parameters + * + * @return Primary event censored CDF, normalized by D if finite (truncation adjustment) + */ +real primarycensored_cdf(data real d, int dist_id, array[] real params, + data real pwindow, data real D, + int primary_id, + array[] real primary_params) { + real result; + if (d <= 0) { + return 0; + } + + if (d >= D) { + return 1; + } + + // Check if an analytical solution exists + if (check_for_analytical(dist_id, primary_id)) { + // Use analytical solution + result = primarycensored_analytical_cdf( + d | dist_id, params, pwindow, D, primary_id, primary_params + ); + } else { + // Use numerical integration for other cases + real lower_bound = max({d - pwindow, 1e-6}); + array[size(params) + size(primary_params)] real theta = append_array(params, primary_params); + array[4] int ids = {dist_id, primary_id, size(params), size(primary_params)}; + + vector[1] y0 = rep_vector(0.0, 1); + result = ode_rk45(primarycensored_ode, y0, lower_bound, {d}, theta, {d, pwindow}, ids)[1, 1]; + + if (!is_inf(D)) { + real log_cdf_D = primarycensored_lcdf( + D | dist_id, params, pwindow, positive_infinity(), primary_id,primary_params + ); + result = exp(log(result) - log_cdf_D); + } + } + + return result; +} + +/** + * Compute the primary event censored log CDF for a single delay + * + * @param d Delay + * @param dist_id Distribution identifier + * @param params Array of distribution parameters + * @param pwindow Primary event window + * @param D Maximum delay (truncation point) + * @param primary_id Primary distribution identifier + * @param primary_params Primary distribution parameters + * + * @return Primary event censored log CDF, normalized by D if finite (truncation adjustment) + * + * @code + * // Example: Weibull delay distribution with uniform primary distribution + * real d = 3.0; + * int dist_id = 5; // Weibull + * array[2] real params = {2.0, 1.5}; // shape and scale + * real pwindow = 1.0; + * real D = positive_infinity(); + * int primary_id = 1; // Uniform + * array[0] real primary_params = {}; + * real log_cdf = primarycensored_lcdf( + * d, dist_id, params, pwindow, D, primary_id, primary_params + * ); + * @endcode + */ +real primarycensored_lcdf(data real d, int dist_id, array[] real params, + data real pwindow, data real D, + int primary_id, + array[] real primary_params) { + real result; + + if (d <= 0) { + return negative_infinity(); + } + + if (d >= D) { + return 0; + } + + // Check if an analytical solution exists + if (check_for_analytical(dist_id, primary_id)) { + result = primarycensored_analytical_lcdf( + d | dist_id, params, pwindow, positive_infinity(), primary_id, primary_params + ); + } else { + // Use numerical integration + result = log(primarycensored_cdf( + d | dist_id, params, pwindow, positive_infinity(), primary_id, primary_params + )); + } + + // Handle truncation + if (!is_inf(D)) { + real log_cdf_D = primarycensored_lcdf( + D | dist_id, params, pwindow, positive_infinity(), primary_id, primary_params + ); + result = result - log_cdf_D; + } + + return result; +} + +/** + * Compute the primary event censored log PMF for a single delay + * + * @param d Delay (integer) + * @param dist_id Distribution identifier + * @param params Array of distribution parameters + * @param pwindow Primary event window + * @param d_upper Upper bound for the delay interval + * @param D Maximum delay (truncation point) + * @param primary_id Primary distribution identifier + * @param primary_params Primary distribution parameters + * + * @return Primary event censored log PMF, normalized by D if finite (truncation adjustment) + * + * @code + * // Example: Weibull delay distribution with uniform primary distribution + * int d = 3; + * int dist_id = 5; // Weibull + * array[2] real params = {2.0, 1.5}; // shape and scale + * real pwindow = 1.0; + * real d_upper = 4.0; + * real D = positive_infinity(); + * int primary_id = 1; // Uniform + * array[0] real primary_params = {}; + * real log_pmf = primarycensored_lpmf( + * d, dist_id, params, pwindow, d_upper, D, primary_id, primary_params + * ); + * @endcode + */ +real primarycensored_lpmf(data int d, int dist_id, array[] real params, + data real pwindow, data real d_upper, + data real D, int primary_id, + array[] real primary_params) { + if (d_upper > D) { + reject("Upper truncation point is greater than D. It is ", d_upper, + " and D is ", D, ". Resolve this by increasing D to be greater or equal to d + swindow or decreasing swindow."); + } + if (d_upper <= d) { + reject("Upper truncation point is less than or equal to d. It is ", d_upper, + " and d is ", d, ". Resolve this by increasing d to be less than d_upper."); + } + real log_cdf_upper = primarycensored_lcdf( + d_upper | dist_id, params, pwindow, positive_infinity(), primary_id, primary_params + ); + real log_cdf_lower = primarycensored_lcdf( + d | dist_id, params, pwindow, positive_infinity(), primary_id, primary_params + ); + if (!is_inf(D)) { + real log_cdf_D; + + if (d_upper == D) { + log_cdf_D = log_cdf_upper; + } else { + log_cdf_D = primarycensored_lcdf( + D | dist_id, params, pwindow, positive_infinity(), primary_id, primary_params + ); + } + return log_diff_exp(log_cdf_upper, log_cdf_lower) - log_cdf_D; + } else { + return log_diff_exp(log_cdf_upper, log_cdf_lower); + } +} + +/** + * Compute the primary event censored PMF for a single delay + * + * @param d Delay (integer) + * @param dist_id Distribution identifier + * @param params Array of distribution parameters + * @param pwindow Primary event window + * @param d_upper Upper bound for the delay interval + * @param D Maximum delay (truncation point) + * @param primary_id Primary distribution identifier + * @param primary_params Primary distribution parameters + * + * @return Primary event censored PMF, normalized by D if finite (truncation adjustment) + * + * @code + * // Example: Weibull delay distribution with uniform primary distribution + * int d = 3; + * real d = 3.0; + * int dist_id = 5; // Weibull + * array[2] real params = {2.0, 1.5}; // shape and scale + * real pwindow = 1.0; + * real swindow = 0.1; + * real D = positive_infinity(); + * int primary_id = 1; // Uniform + * array[0] real primary_params = {}; + * real pmf = primarycensored_pmf(d, dist_id, params, pwindow, swindow, D, primary_id, primary_params); + * @endcode + */ +real primarycensored_pmf(data int d, int dist_id, array[] real params, + data real pwindow, data real d_upper, + data real D, int primary_id, + array[] real primary_params) { + return exp( + primarycensored_lpmf( + d | dist_id, params, pwindow, d_upper, D, primary_id, primary_params + ) + ); +} + +/** + * Compute the primary event censored log PMF for integer delays up to max_delay + * + * @param max_delay Maximum delay to compute PMF for + * @param D Maximum delay (truncation point), must be at least max_delay + 1 + * @param dist_id Distribution identifier + * @param params Array of distribution parameters + * @param pwindow Primary event window + * @param primary_id Primary distribution identifier + * @param primary_params Primary distribution parameters + * + * @return Vector of primary event censored log PMFs for delays \[0, 1\] to + * \[max_delay, max_delay + 1\]. + * + * This function differs from primarycensored_lpmf in that it: + * 1. Computes PMFs for all integer delays from \[0, 1\] to \[max_delay, + * max_delay + 1\] in one call. + * 2. Assumes integer delays (swindow = 1) + * 3. Is more computationally efficient for multiple delay calculation as it + * reduces the number of integration calls. + * + * @code + * // Example: Weibull delay distribution with uniform primary distribution + * int max_delay = 10; + * real D = 15.0; + * int dist_id = 5; // Weibull + * array[2] real params = {2.0, 1.5}; // shape and scale + * real pwindow = 7.0; + * int primary_id = 1; // Uniform + * array[0] real primary_params = {}; + + * vector[max_delay] log_pmf = + * primarycensored_sone_lpmf_vectorized( + * max_delay, D, dist_id, params, pwindow, primary_id, + * primary_params + * ); + * @endcode + */ +vector primarycensored_sone_lpmf_vectorized( + int max_delay, data real D, int dist_id, + array[] real params, data real pwindow, + int primary_id, array[] real primary_params +) { + + int upper_interval = max_delay + 1; + vector[upper_interval] log_pmfs; + vector[upper_interval] log_cdfs; + real log_normalizer; + + // Check if D is at least max_delay + 1 + if (D < upper_interval) { + reject("D must be at least max_delay + 1"); + } + + // Compute log CDFs + for (d in 1:upper_interval) { + log_cdfs[d] = primarycensored_lcdf( + d | dist_id, params, pwindow, positive_infinity(), primary_id, + primary_params + ); + } + + // Compute log normalizer using upper_interval + if (D > upper_interval) { + if (is_inf(D)) { + log_normalizer = 0; // No normalization needed for infinite D + } else { + log_normalizer = primarycensored_lcdf( + D | dist_id, params, pwindow, positive_infinity(), + primary_id, primary_params + ); + } + } else { + log_normalizer = log_cdfs[upper_interval]; + } + + // Compute log PMFs + log_pmfs[1] = log_cdfs[1] - log_normalizer; + for (d in 2:upper_interval) { + log_pmfs[d] = log_diff_exp(log_cdfs[d], log_cdfs[d-1]) - log_normalizer; + } + + return log_pmfs; +} + +/** + * Compute the primary event censored PMF for integer delays up to max_delay + * + * @param max_delay Maximum delay to compute PMF for + * @param D Maximum delay (truncation point), must be at least max_delay + 1 + * @param dist_id Distribution identifier + * @param params Array of distribution parameters + * @param pwindow Primary event window + * @param primary_id Primary distribution identifier + * @param primary_params Primary distribution parameters + * + * @return Vector of primary event censored PMFs for integer delays 1 to + * max_delay + * + * This function differs from primarycensored_pmf in that it: + * 1. Computes PMFs for all integer delays from \[0, 1\] to \[max_delay, + * max_delay + 1\] in one call. + * 2. Assumes integer delays (swindow = 1) + * 3. Is more computationally efficient for multiple delay calculations + * + * @code + * // Example: Weibull delay distribution with uniform primary distribution + * int max_delay = 10; + * real D = 15.0; + * int dist_id = 5; // Weibull + * array[2] real params = {2.0, 1.5}; // shape and scale + * real pwindow = 7.0; + * int primary_id = 1; // Uniform + * array[0] real primary_params = {}; + * vector[max_delay] pmf = + * primarycensored_sone_lpmf_vectorized( + * max_delay, D, dist_id, params, pwindow, primary_id, primary_params + * ); + * @endcode + */ +vector primarycensored_sone_pmf_vectorized( + int max_delay, data real D, int dist_id, + array[] real params, data real pwindow, + int primary_id, + array[] real primary_params +) { + return exp( + primarycensored_sone_lpmf_vectorized( + max_delay, D, dist_id, params, pwindow, primary_id, primary_params + ) + ); +} + + +/** + * Check if an analytical solution exists for the given distribution combination + * + * @param dist_id Distribution identifier for the delay distribution + * @param primary_id Distribution identifier for the primary distribution + * + * @return 1 if an analytical solution exists, 0 otherwise + */ +int check_for_analytical(int dist_id, int primary_id) { + if (dist_id == 2 && primary_id == 1) return 1; // Gamma delay with Uniform primary + if (dist_id == 1 && primary_id == 1) return 1; // Lognormal delay with Uniform primary + if (dist_id == 3 && primary_id == 1) return 1; // Weibull delay with Uniform primary + return 0; // No analytical solution for other combinations +} + +/** + * Compute the primary event censored log CDF analytically for Gamma delay with Uniform primary + * + * @param d Delay time + * @param q Lower bound of integration (max(d - pwindow, 0)) + * @param params Array of Gamma distribution parameters [shape, rate] + * @param pwindow Primary event window + * + * @return Log of the primary event censored CDF for Gamma delay with Uniform + * primary + */ +real primarycensored_gamma_uniform_lcdf(data real d, real q, array[] real params, data real pwindow) { + real shape = params[1]; + real rate = params[2]; + real shape_1 = shape + 1; + real log_window = log(pwindow); + + real log_F_T = gamma_lcdf(d | shape, rate); + real log_F_T_kp1 = gamma_lcdf(d | shape_1, rate); + + real log_delta_F_T_kp1; + real log_delta_F_T_k; + real log_F_Splus; + + if (q != 0) { + real log_F_T_q = gamma_lcdf(q | shape, rate); + real log_F_T_q_kp1 = gamma_lcdf(q | shape_1, rate); + + // Ensure that the first argument is greater than the second + log_delta_F_T_kp1 = log_diff_exp(log_F_T_kp1, log_F_T_q_kp1); + log_delta_F_T_k = log_diff_exp(log_F_T, log_F_T_q); + + log_F_Splus = log_diff_exp( + log_F_T, + log_diff_exp( + log(shape * inv(rate)) + log_delta_F_T_kp1, + log(d - pwindow) + log_delta_F_T_k + ) - log_window + ); + } else { + log_delta_F_T_kp1 = log_F_T_kp1; + log_delta_F_T_k = log_F_T; + + log_F_Splus = log_diff_exp( + log_F_T, + log_sum_exp( + log(shape * inv(rate)) + log_delta_F_T_kp1, + log(pwindow - d) + log_delta_F_T_k + ) - log_window + ); + } + + return log_F_Splus; +} + +/** + * Compute the primary event censored log CDF analytically for Lognormal delay with Uniform primary + * + * @param d Delay time + * @param q Lower bound of integration (max(d - pwindow, 0)) + * @param params Array of Lognormal distribution parameters [mu, sigma] + * @param pwindow Primary event window + * + * @return Log of the primary event censored CDF for Lognormal delay with + * Uniform primary + */ +real primarycensored_lognormal_uniform_lcdf(data real d, real q, array[] real params, data real pwindow) { + real mu = params[1]; + real sigma = params[2]; + real mu_sigma2 = mu + square(sigma); + real log_window = log(pwindow); + + real log_F_T = lognormal_lcdf(d | mu, sigma); + real log_F_T_mu_sigma2 = lognormal_lcdf(d | mu_sigma2, sigma); + + real log_delta_F_T_mu_sigma; + real log_delta_F_T; + real log_F_Splus; + + if (q != 0) { + real log_F_T_q = lognormal_lcdf(q | mu, sigma); + real log_F_T_q_mu_sigma2 = lognormal_lcdf(q | mu_sigma2, sigma); + + // Ensure that the first argument is greater than the second + log_delta_F_T_mu_sigma = log_diff_exp( + log_F_T_mu_sigma2, log_F_T_q_mu_sigma2 + ); + log_delta_F_T = log_diff_exp(log_F_T, log_F_T_q); + + log_F_Splus = log_diff_exp( + log_F_T, + log_diff_exp( + (mu + 0.5 * square(sigma)) + log_delta_F_T_mu_sigma, + log(d - pwindow) + log_delta_F_T + ) - log_window + ); + } else { + log_delta_F_T_mu_sigma = log_F_T_mu_sigma2; + log_delta_F_T = log_F_T; + + log_F_Splus = log_diff_exp( + log_F_T, + log_sum_exp( + (mu + 0.5 * square(sigma)) + log_delta_F_T_mu_sigma, + log(pwindow - d) + log_delta_F_T + ) - log_window + ); + } + + return log_F_Splus; +} + +/** + * Compute the log of the lower incomplete gamma function + * + * This function is used in the analytical solution for the primary censored + * Weibull distribution with uniform primary censoring. It corresponds to the + * g(t; λ, k) function described in the analytic solutions document. + * + * @param t Upper bound of integration + * @param shape Shape parameter (k) of the Weibull distribution + * @param scale Scale parameter (λ) of the Weibull distribution + * + * @return Log of g(t; λ, k) = γ(1 + 1/k, (t/λ)^k) + */ +real log_weibull_g(real t, real shape, real scale) { + real x = pow(t * inv(scale), shape); + real a = 1 + inv(shape); + return log(gamma_p(a, x)) + lgamma(a); +} + +/** + * Compute the primary event censored log CDF analytically for Weibull delay with Uniform primary + * + * @param d Delay time + * @param q Lower bound of integration (max(d - pwindow, 0)) + * @param params Array of Weibull distribution parameters [shape, scale] + * @param pwindow Primary event window + * + * @return Log of the primary event censored CDF for Weibull delay with + * Uniform primary + */ +real primarycensored_weibull_uniform_lcdf(data real d, real q, array[] real params, data real pwindow) { + real shape = params[1]; + real scale = params[2]; + real log_window = log(pwindow); + + real log_F_T = weibull_lcdf(d | shape, scale); + + real log_delta_g; + real log_delta_F_T; + real log_F_Splus; + + if (q != 0) { + real log_F_T_q = weibull_lcdf(q | shape, scale); + + log_delta_g = log_diff_exp( + log_weibull_g(d, shape, scale), + log_weibull_g(q, shape, scale) + ); + log_delta_F_T = log_diff_exp(log_F_T, log_F_T_q); + + log_F_Splus = log_diff_exp( + log_F_T, + log_diff_exp( + log(scale) + log_delta_g, + log(d - pwindow) + log_delta_F_T + ) - log_window + ); + } else { + log_delta_g = log_weibull_g(d, shape, scale); + log_delta_F_T = log_F_T; + + log_F_Splus = log_diff_exp( + log_F_T, + log_sum_exp( + log(scale) + log_delta_g, + log(pwindow - d) + log_delta_F_T + ) - log_window + ); + } + + return log_F_Splus; +} + +/** + * Compute the primary event censored log CDF analytically for a single delay + * + * @param d Delay + * @param dist_id Distribution identifier + * @param params Array of distribution parameters + * @param pwindow Primary event window + * @param D Maximum delay (truncation point) + * @param primary_id Primary distribution identifier + * @param primary_params Primary distribution parameters + * + * @return Primary event censored log CDF, normalized by D if finite (truncation adjustment) + */ +real primarycensored_analytical_lcdf(data real d, int dist_id, + array[] real params, + data real pwindow, data real D, + int primary_id, + array[] real primary_params) { + real result; + real log_cdf_D; + + if (d <= 0) return negative_infinity(); + if (d >= D) return 0; + + real q = max({d - pwindow, 0}); + + if (dist_id == 2 && primary_id == 1) { + // Gamma delay with Uniform primary + result = primarycensored_gamma_uniform_lcdf(d | q, params, pwindow); + } else if (dist_id == 1 && primary_id == 1) { + // Lognormal delay with Uniform primary + result = primarycensored_lognormal_uniform_lcdf(d | q, params, pwindow); + } else if (dist_id == 3 && primary_id == 1) { + // Weibull delay with Uniform primary + result = primarycensored_weibull_uniform_lcdf(d | q, params, pwindow); + } else { + // No analytical solution available + return negative_infinity(); + } + + if (!is_inf(D)) { + log_cdf_D = primarycensored_lcdf( + D | dist_id, params, pwindow, positive_infinity(), + primary_id, primary_params + ); + result = result - log_cdf_D; + } + + return result; +} + +/** + * Compute the primary event censored CDF analytically for a single delay + * + * @param d Delay + * @param dist_id Distribution identifier + * @param params Array of distribution parameters + * @param pwindow Primary event window + * @param D Maximum delay (truncation point) + * @param primary_id Primary distribution identifier + * @param primary_params Primary distribution parameters + * + * @return Primary event censored CDF, normalized by D if finite (truncation adjustment) + */ +real primarycensored_analytical_cdf(data real d, int dist_id, + array[] real params, + data real pwindow, data real D, + int primary_id, + array[] real primary_params) { + return exp(primarycensored_analytical_lcdf(d | dist_id, params, pwindow, D, primary_id, primary_params)); +} + +/** + * Compute the log CDF of the delay distribution + * + * @param delay Time delay + * @param params Distribution parameters + * @param dist_id Distribution identifier + * 1: Lognormal, 2: Gamma, 3: Normal, 4: Exponential, 5: Weibull, + * 6: Beta, 7: Cauchy, 8: Chi-square, 9: Inverse Chi-square, + * 10: Double Exponential, 11: Inverse Gamma, 12: Logistic, + * 13: Pareto, 14: Scaled Inverse Chi-square, 15: Student's t, + * 16: Uniform, 17: von Mises + * + * @return Log CDF of the delay distribution + * + * @code + * // Example: Lognormal distribution + * real delay = 5.0; + * array[2] real params = {0.0, 1.0}; // mean and standard deviation on log scale + * int dist_id = 1; // Lognormal + * real log_cdf = dist_lcdf(delay, params, dist_id); + * @endcode + */ +real dist_lcdf(real delay, array[] real params, int dist_id) { + if (delay <= 0) return negative_infinity(); + + // Use if-else statements to handle different distribution types + if (dist_id == 1) return lognormal_lcdf(delay | params[1], params[2]); + else if (dist_id == 2) return gamma_lcdf(delay | params[1], params[2]); + else if (dist_id == 3) return normal_lcdf(delay | params[1], params[2]); + else if (dist_id == 4) return exponential_lcdf(delay | params[1]); + else if (dist_id == 5) return weibull_lcdf(delay | params[1], params[2]); + else if (dist_id == 6) return beta_lcdf(delay | params[1], params[2]); + else if (dist_id == 7) return cauchy_lcdf(delay | params[1], params[2]); + else if (dist_id == 8) return chi_square_lcdf(delay | params[1]); + else if (dist_id == 9) return inv_chi_square_lcdf(delay | params[1]); + else if (dist_id == 10) return double_exponential_lcdf(delay | params[1], params[2]); + else if (dist_id == 11) return inv_gamma_lcdf(delay | params[1], params[2]); + else if (dist_id == 12) return logistic_lcdf(delay | params[1], params[2]); + else if (dist_id == 13) return pareto_lcdf(delay | params[1], params[2]); + else if (dist_id == 14) return scaled_inv_chi_square_lcdf(delay | params[1], params[2]); + else if (dist_id == 15) return student_t_lcdf(delay | params[1], params[2], params[3]); + else if (dist_id == 16) return uniform_lcdf(delay | params[1], params[2]); + else if (dist_id == 17) return von_mises_lcdf(delay | params[1], params[2]); + else reject("Invalid distribution identifier"); +} + +/** + * Compute the log PDF of the primary distribution + * + * @param x Value + * @param primary_id Primary distribution identifier + * @param params Distribution parameters + * @param min Minimum value + * @param max Maximum value + * + * @return Log PDF of the primary distribution + * + * @code + * // Example: Uniform distribution + * real x = 0.5; + * int primary_id = 1; // Uniform + * array[0] real params = {}; // No additional parameters for uniform + * real min = 0; + * real max = 1; + * real log_pdf = primary_lpdf(x, primary_id, params, min, max); + * @endcode + */ +real primary_lpdf(real x, int primary_id, array[] real params, real min, real max) { + // Implement switch for different primary distributions + if (primary_id == 1) return uniform_lpdf(x | min, max); + if (primary_id == 2) return expgrowth_lpdf(x | min, max, params[1]); + // Add more primary distributions as needed + reject("Invalid primary distribution identifier"); +} + +/** + * ODE system for the primary censored distribution + * + * @param t Time + * @param y State variables + * @param theta Parameters + * @param x_r Real data + * @param x_i Integer data + * + * @return Derivatives of the state variables + */ +vector primarycensored_ode(real t, vector y, array[] real theta, + array[] real x_r, array[] int x_i) { + real d = x_r[1]; + int dist_id = x_i[1]; + int primary_id = x_i[2]; + real pwindow = x_r[2]; + int dist_params_len = x_i[3]; + int primary_params_len = x_i[4]; + + // Extract distribution parameters + array[dist_params_len] real params; + if (dist_params_len) { + params = theta[1:dist_params_len]; + } + array[primary_params_len] real primary_params; + if (primary_params_len) { + int primary_loc = size(theta); + primary_params = theta[primary_loc - primary_params_len + 1:primary_loc]; + } + + real log_cdf = dist_lcdf(t | params, dist_id); + real log_primary_pdf = primary_lpdf(d - t | primary_id, primary_params, 0, pwindow); + + return rep_vector(exp(log_cdf + log_primary_pdf), 1); +} + +/** + * Exponential growth probability density function (PDF) + * + * @param x Value at which to evaluate the PDF + * @param min Lower bound of the distribution + * @param max Upper bound of the distribution + * @param r Rate parameter for exponential growth + * @return The PDF evaluated at x + */ +real expgrowth_pdf(real x, real min, real max, real r) { + if (x < min || x > max) { + return 0; + } + if (abs(r) < 1e-10) { + return 1 / (max - min); + } + return r * exp(r * (x - min)) / (exp(r * max) - exp(r * min)); +} + +/** + * Exponential growth log probability density function (log PDF) + * + * @param x Value at which to evaluate the log PDF + * @param min Lower bound of the distribution + * @param max Upper bound of the distribution + * @param r Rate parameter for exponential growth + * @return The log PDF evaluated at x + */ +real expgrowth_lpdf(real x, real min, real max, real r) { + if (x < min || x > max) { + return negative_infinity(); + } + if (abs(r) < 1e-10) { + return -log(max - min); + } + return log(r) + r * (x - min) - log(exp(r * max) - exp(r * min)); +} + +/** + * Exponential growth cumulative distribution function (CDF) + * + * @param x Value at which to evaluate the CDF + * @param min Lower bound of the distribution + * @param max Upper bound of the distribution + * @param r Rate parameter for exponential growth + * @return The CDF evaluated at x + */ +real expgrowth_cdf(real x, real min, real max, real r) { + if (x < min) { + return 0; + } + if (x > max) { + return 1; + } + if (abs(r) < 1e-10) { + return (x - min) / (max - min); + } + return (exp(r * (x - min)) - exp(r * min)) / (exp(r * max) - exp(r * min)); +} + +/** + * Exponential growth log cumulative distribution function (log CDF) + * + * @param x Value at which to evaluate the log CDF + * @param min Lower bound of the distribution + * @param max Upper bound of the distribution + * @param r Rate parameter for exponential growth + * @return The log CDF evaluated at x + */ +real expgrowth_lcdf(real x, real min, real max, real r) { + if (x < min) { + return negative_infinity(); + } + if (x > max) { + return 0; + } + return log(expgrowth_cdf(x | min, max, r)); +} + +/** + * Exponential growth random number generator + * + * @param min Lower bound of the distribution + * @param max Upper bound of the distribution + * @param r Rate parameter for exponential growth + * @return A random draw from the exponential growth distribution + */ +real expgrowth_rng(real min, real max, real r) { + real u = uniform_rng(0, 1); + if (abs(r) < 1e-10) { + return min + u * (max - min); + } + return min + log(u * (exp(r * max) - exp(r * min)) + exp(r * min)) / r; +} From a107c2e092413ebe0fecc2273df27c585bf109e2 Mon Sep 17 00:00:00 2001 From: athowes Date: Wed, 13 Nov 2024 12:12:29 +0000 Subject: [PATCH 13/62] Wrap up on attempt --- inst/cohort-scratch.R | 6 +++--- inst/stan/cohort_model/primarycensored-edit.stan | 6 ++++-- 2 files changed, 7 insertions(+), 5 deletions(-) diff --git a/inst/cohort-scratch.R b/inst/cohort-scratch.R index 61d3b5b27..0131e8bf3 100644 --- a/inst/cohort-scratch.R +++ b/inst/cohort-scratch.R @@ -49,7 +49,7 @@ summary(fit_direct_weighted) lognormal <- brms::lognormal() -primarycensored_lognormal_uniform_lpmf <- brms::custom_family( +primarycensored_lognormal_uniform <- brms::custom_family( "primarycensored_lognormal_uniform", dpars = lognormal$dpar, links = c(lognormal$link, lognormal$link_sigma), @@ -92,7 +92,7 @@ stanvars_all <- stanvars_functions + stanvars_data stancode <- brms::make_stancode( formula = d | weights(n) + vreal(q) ~ 1, - family = primarycensored_lognormal_uniform_lpmf, + family = primarycensored_lognormal_uniform, data = data, stanvars = stanvars_all, ) @@ -101,7 +101,7 @@ model <- rstan::stan_model(model_code = stancode) fit_pcd <- brms::brm( formula = d | weights(n) + vreal(q) ~ 1, - family = primarycensored_lognormal_uniform_lpmf, + family = primarycensored_lognormal_uniform, data = data, stanvars = stanvars_all, backend = "cmdstanr" diff --git a/inst/stan/cohort_model/primarycensored-edit.stan b/inst/stan/cohort_model/primarycensored-edit.stan index 0b2899b70..d0a77774e 100644 --- a/inst/stan/cohort_model/primarycensored-edit.stan +++ b/inst/stan/cohort_model/primarycensored-edit.stan @@ -2,10 +2,12 @@ // Edited to work with temporary brms function real primarycensored_lognormal_uniform_lpmf(data int d, real mu, real sigma, real q, data real pwindow) { + int dist_id = 1; // lognormal array[2] real params = {mu, sigma}; - array[0] real primary_params; int d_upper = d + 1; - return primarycensored_lpmf(d | 1, params, pwindow, d_upper, positive_infinity(), 1, primary_params); + int primary_id = 1; // Uniform + array[0] real primary_params; + return primarycensored_lpmf(d | dist_id, params, pwindow, d_upper, positive_infinity(), primary_id, primary_params); } /** From 4a0b5c525dec362f9e8a03b1bfe5c4fffc5c567c Mon Sep 17 00:00:00 2001 From: athowes Date: Thu, 14 Nov 2024 11:14:16 +0000 Subject: [PATCH 14/62] Rename to marginal model --- inst/{cohort-scratch.R => marginal_model-scratch.R} | 0 inst/stan/{cohort_model => marginal_model}/data.stan | 0 .../{cohort_model => marginal_model}/primarycensored-edit.stan | 0 inst/stan/{cohort_model => marginal_model}/tdata.stan | 0 inst/stan/{cohort_model => marginal_model}/tparameters.stan | 0 5 files changed, 0 insertions(+), 0 deletions(-) rename inst/{cohort-scratch.R => marginal_model-scratch.R} (100%) rename inst/stan/{cohort_model => marginal_model}/data.stan (100%) rename inst/stan/{cohort_model => marginal_model}/primarycensored-edit.stan (100%) rename inst/stan/{cohort_model => marginal_model}/tdata.stan (100%) rename inst/stan/{cohort_model => marginal_model}/tparameters.stan (100%) diff --git a/inst/cohort-scratch.R b/inst/marginal_model-scratch.R similarity index 100% rename from inst/cohort-scratch.R rename to inst/marginal_model-scratch.R diff --git a/inst/stan/cohort_model/data.stan b/inst/stan/marginal_model/data.stan similarity index 100% rename from inst/stan/cohort_model/data.stan rename to inst/stan/marginal_model/data.stan diff --git a/inst/stan/cohort_model/primarycensored-edit.stan b/inst/stan/marginal_model/primarycensored-edit.stan similarity index 100% rename from inst/stan/cohort_model/primarycensored-edit.stan rename to inst/stan/marginal_model/primarycensored-edit.stan diff --git a/inst/stan/cohort_model/tdata.stan b/inst/stan/marginal_model/tdata.stan similarity index 100% rename from inst/stan/cohort_model/tdata.stan rename to inst/stan/marginal_model/tdata.stan diff --git a/inst/stan/cohort_model/tparameters.stan b/inst/stan/marginal_model/tparameters.stan similarity index 100% rename from inst/stan/cohort_model/tparameters.stan rename to inst/stan/marginal_model/tparameters.stan From 6bf847916e67d4f743636faf8a686109adc3109d Mon Sep 17 00:00:00 2001 From: athowes Date: Fri, 15 Nov 2024 21:04:27 +0000 Subject: [PATCH 15/62] Move towards single wrap function with others imported --- inst/marginal_model-scratch.R | 29 +- inst/stan/marginal_model/functions.stan | 8 + .../marginal_model/primarycensored-edit.stan | 846 ------------------ 3 files changed, 20 insertions(+), 863 deletions(-) create mode 100644 inst/stan/marginal_model/functions.stan delete mode 100644 inst/stan/marginal_model/primarycensored-edit.stan diff --git a/inst/marginal_model-scratch.R b/inst/marginal_model-scratch.R index 0131e8bf3..3753e10c1 100644 --- a/inst/marginal_model-scratch.R +++ b/inst/marginal_model-scratch.R @@ -19,7 +19,7 @@ sim_obs <- simulate_gillespie() |> sdlog = sdlog ) |> observe_process() |> - filter_obs_by_obs_time(obs_time = obs_time) |> + filter(.data$stime_upr <= obs_time) |> dplyr::slice_sample(n = sample_size, replace = FALSE) # Create cohort version of data @@ -49,8 +49,8 @@ summary(fit_direct_weighted) lognormal <- brms::lognormal() -primarycensored_lognormal_uniform <- brms::custom_family( - "primarycensored_lognormal_uniform", +primarycensored_family <- brms::custom_family( + "primarycensored_wrapper", dpars = lognormal$dpar, links = c(lognormal$link, lognormal$link_sigma), type = "int", @@ -65,20 +65,15 @@ data <- cohort_obs |> q = pmax(d - pwindow, 0) ) -stanvars_functions <- brms::stanvar( +pcd_stanvars_functions <- brms::stanvar( block = "functions", - scode = .stan_chunk("cohort_model/primarycensored-edit.stan") + scode = pcd_load_stan_functions() ) -# stanvars_tparameters <- brms::stanvar( -# block = "tparameters", -# scode = .stan_chunk("cohort_model/tparameters.stan") -# ) - -# stanvars_tdata <- brms::stanvar( -# block = "tdata", -# scode = .stan_chunk("cohort_model/tdata.stan") -# ) +stanvars_functions <- brms::stanvar( + block = "functions", + scode = .stan_chunk("cohort_model/functions.stan") +) pwindow <- data$pwindow @@ -88,11 +83,11 @@ stanvars_data <- brms::stanvar( scode = .stan_chunk("cohort_model/data.stan") ) -stanvars_all <- stanvars_functions + stanvars_data +stanvars_all <- pcd_stanvars_functions + stanvars_functions + stanvars_data stancode <- brms::make_stancode( formula = d | weights(n) + vreal(q) ~ 1, - family = primarycensored_lognormal_uniform, + family = primarycensored_family, data = data, stanvars = stanvars_all, ) @@ -101,7 +96,7 @@ model <- rstan::stan_model(model_code = stancode) fit_pcd <- brms::brm( formula = d | weights(n) + vreal(q) ~ 1, - family = primarycensored_lognormal_uniform, + family = primarycensored_family, data = data, stanvars = stanvars_all, backend = "cmdstanr" diff --git a/inst/stan/marginal_model/functions.stan b/inst/stan/marginal_model/functions.stan new file mode 100644 index 000000000..5df0ee1c6 --- /dev/null +++ b/inst/stan/marginal_model/functions.stan @@ -0,0 +1,8 @@ +real primarycensored_wrapper_lpmf(data int d, real mu, real sigma, real q, data real pwindow) { + int dist_id = 1; // lognormal + array[2] real params = {mu, sigma}; + int d_upper = d + 1; + int primary_id = 1; // Uniform + array[0] real primary_params; + return primarycensored_lpmf(d | dist_id, params, pwindow, d_upper, positive_infinity(), primary_id, primary_params); +} diff --git a/inst/stan/marginal_model/primarycensored-edit.stan b/inst/stan/marginal_model/primarycensored-edit.stan deleted file mode 100644 index d0a77774e..000000000 --- a/inst/stan/marginal_model/primarycensored-edit.stan +++ /dev/null @@ -1,846 +0,0 @@ -// Copied from https://github.com/epinowcast/primarycensored/blob/main/inst/stan/functions/ -// Edited to work with temporary brms function - -real primarycensored_lognormal_uniform_lpmf(data int d, real mu, real sigma, real q, data real pwindow) { - int dist_id = 1; // lognormal - array[2] real params = {mu, sigma}; - int d_upper = d + 1; - int primary_id = 1; // Uniform - array[0] real primary_params; - return primarycensored_lpmf(d | dist_id, params, pwindow, d_upper, positive_infinity(), primary_id, primary_params); -} - -/** - * Primary event censored distribution functions - */ - -/** - * Compute the primary event censored CDF for a single delay - * - * @param d Delay - * @param dist_id Distribution identifier - * @param params Array of distribution parameters - * @param pwindow Primary event window - * @param D Maximum delay (truncation point) - * @param primary_id Primary distribution identifier - * @param primary_params Primary distribution parameters - * - * @return Primary event censored CDF, normalized by D if finite (truncation adjustment) - */ -real primarycensored_cdf(data real d, int dist_id, array[] real params, - data real pwindow, data real D, - int primary_id, - array[] real primary_params) { - real result; - if (d <= 0) { - return 0; - } - - if (d >= D) { - return 1; - } - - // Check if an analytical solution exists - if (check_for_analytical(dist_id, primary_id)) { - // Use analytical solution - result = primarycensored_analytical_cdf( - d | dist_id, params, pwindow, D, primary_id, primary_params - ); - } else { - // Use numerical integration for other cases - real lower_bound = max({d - pwindow, 1e-6}); - array[size(params) + size(primary_params)] real theta = append_array(params, primary_params); - array[4] int ids = {dist_id, primary_id, size(params), size(primary_params)}; - - vector[1] y0 = rep_vector(0.0, 1); - result = ode_rk45(primarycensored_ode, y0, lower_bound, {d}, theta, {d, pwindow}, ids)[1, 1]; - - if (!is_inf(D)) { - real log_cdf_D = primarycensored_lcdf( - D | dist_id, params, pwindow, positive_infinity(), primary_id,primary_params - ); - result = exp(log(result) - log_cdf_D); - } - } - - return result; -} - -/** - * Compute the primary event censored log CDF for a single delay - * - * @param d Delay - * @param dist_id Distribution identifier - * @param params Array of distribution parameters - * @param pwindow Primary event window - * @param D Maximum delay (truncation point) - * @param primary_id Primary distribution identifier - * @param primary_params Primary distribution parameters - * - * @return Primary event censored log CDF, normalized by D if finite (truncation adjustment) - * - * @code - * // Example: Weibull delay distribution with uniform primary distribution - * real d = 3.0; - * int dist_id = 5; // Weibull - * array[2] real params = {2.0, 1.5}; // shape and scale - * real pwindow = 1.0; - * real D = positive_infinity(); - * int primary_id = 1; // Uniform - * array[0] real primary_params = {}; - * real log_cdf = primarycensored_lcdf( - * d, dist_id, params, pwindow, D, primary_id, primary_params - * ); - * @endcode - */ -real primarycensored_lcdf(data real d, int dist_id, array[] real params, - data real pwindow, data real D, - int primary_id, - array[] real primary_params) { - real result; - - if (d <= 0) { - return negative_infinity(); - } - - if (d >= D) { - return 0; - } - - // Check if an analytical solution exists - if (check_for_analytical(dist_id, primary_id)) { - result = primarycensored_analytical_lcdf( - d | dist_id, params, pwindow, positive_infinity(), primary_id, primary_params - ); - } else { - // Use numerical integration - result = log(primarycensored_cdf( - d | dist_id, params, pwindow, positive_infinity(), primary_id, primary_params - )); - } - - // Handle truncation - if (!is_inf(D)) { - real log_cdf_D = primarycensored_lcdf( - D | dist_id, params, pwindow, positive_infinity(), primary_id, primary_params - ); - result = result - log_cdf_D; - } - - return result; -} - -/** - * Compute the primary event censored log PMF for a single delay - * - * @param d Delay (integer) - * @param dist_id Distribution identifier - * @param params Array of distribution parameters - * @param pwindow Primary event window - * @param d_upper Upper bound for the delay interval - * @param D Maximum delay (truncation point) - * @param primary_id Primary distribution identifier - * @param primary_params Primary distribution parameters - * - * @return Primary event censored log PMF, normalized by D if finite (truncation adjustment) - * - * @code - * // Example: Weibull delay distribution with uniform primary distribution - * int d = 3; - * int dist_id = 5; // Weibull - * array[2] real params = {2.0, 1.5}; // shape and scale - * real pwindow = 1.0; - * real d_upper = 4.0; - * real D = positive_infinity(); - * int primary_id = 1; // Uniform - * array[0] real primary_params = {}; - * real log_pmf = primarycensored_lpmf( - * d, dist_id, params, pwindow, d_upper, D, primary_id, primary_params - * ); - * @endcode - */ -real primarycensored_lpmf(data int d, int dist_id, array[] real params, - data real pwindow, data real d_upper, - data real D, int primary_id, - array[] real primary_params) { - if (d_upper > D) { - reject("Upper truncation point is greater than D. It is ", d_upper, - " and D is ", D, ". Resolve this by increasing D to be greater or equal to d + swindow or decreasing swindow."); - } - if (d_upper <= d) { - reject("Upper truncation point is less than or equal to d. It is ", d_upper, - " and d is ", d, ". Resolve this by increasing d to be less than d_upper."); - } - real log_cdf_upper = primarycensored_lcdf( - d_upper | dist_id, params, pwindow, positive_infinity(), primary_id, primary_params - ); - real log_cdf_lower = primarycensored_lcdf( - d | dist_id, params, pwindow, positive_infinity(), primary_id, primary_params - ); - if (!is_inf(D)) { - real log_cdf_D; - - if (d_upper == D) { - log_cdf_D = log_cdf_upper; - } else { - log_cdf_D = primarycensored_lcdf( - D | dist_id, params, pwindow, positive_infinity(), primary_id, primary_params - ); - } - return log_diff_exp(log_cdf_upper, log_cdf_lower) - log_cdf_D; - } else { - return log_diff_exp(log_cdf_upper, log_cdf_lower); - } -} - -/** - * Compute the primary event censored PMF for a single delay - * - * @param d Delay (integer) - * @param dist_id Distribution identifier - * @param params Array of distribution parameters - * @param pwindow Primary event window - * @param d_upper Upper bound for the delay interval - * @param D Maximum delay (truncation point) - * @param primary_id Primary distribution identifier - * @param primary_params Primary distribution parameters - * - * @return Primary event censored PMF, normalized by D if finite (truncation adjustment) - * - * @code - * // Example: Weibull delay distribution with uniform primary distribution - * int d = 3; - * real d = 3.0; - * int dist_id = 5; // Weibull - * array[2] real params = {2.0, 1.5}; // shape and scale - * real pwindow = 1.0; - * real swindow = 0.1; - * real D = positive_infinity(); - * int primary_id = 1; // Uniform - * array[0] real primary_params = {}; - * real pmf = primarycensored_pmf(d, dist_id, params, pwindow, swindow, D, primary_id, primary_params); - * @endcode - */ -real primarycensored_pmf(data int d, int dist_id, array[] real params, - data real pwindow, data real d_upper, - data real D, int primary_id, - array[] real primary_params) { - return exp( - primarycensored_lpmf( - d | dist_id, params, pwindow, d_upper, D, primary_id, primary_params - ) - ); -} - -/** - * Compute the primary event censored log PMF for integer delays up to max_delay - * - * @param max_delay Maximum delay to compute PMF for - * @param D Maximum delay (truncation point), must be at least max_delay + 1 - * @param dist_id Distribution identifier - * @param params Array of distribution parameters - * @param pwindow Primary event window - * @param primary_id Primary distribution identifier - * @param primary_params Primary distribution parameters - * - * @return Vector of primary event censored log PMFs for delays \[0, 1\] to - * \[max_delay, max_delay + 1\]. - * - * This function differs from primarycensored_lpmf in that it: - * 1. Computes PMFs for all integer delays from \[0, 1\] to \[max_delay, - * max_delay + 1\] in one call. - * 2. Assumes integer delays (swindow = 1) - * 3. Is more computationally efficient for multiple delay calculation as it - * reduces the number of integration calls. - * - * @code - * // Example: Weibull delay distribution with uniform primary distribution - * int max_delay = 10; - * real D = 15.0; - * int dist_id = 5; // Weibull - * array[2] real params = {2.0, 1.5}; // shape and scale - * real pwindow = 7.0; - * int primary_id = 1; // Uniform - * array[0] real primary_params = {}; - - * vector[max_delay] log_pmf = - * primarycensored_sone_lpmf_vectorized( - * max_delay, D, dist_id, params, pwindow, primary_id, - * primary_params - * ); - * @endcode - */ -vector primarycensored_sone_lpmf_vectorized( - int max_delay, data real D, int dist_id, - array[] real params, data real pwindow, - int primary_id, array[] real primary_params -) { - - int upper_interval = max_delay + 1; - vector[upper_interval] log_pmfs; - vector[upper_interval] log_cdfs; - real log_normalizer; - - // Check if D is at least max_delay + 1 - if (D < upper_interval) { - reject("D must be at least max_delay + 1"); - } - - // Compute log CDFs - for (d in 1:upper_interval) { - log_cdfs[d] = primarycensored_lcdf( - d | dist_id, params, pwindow, positive_infinity(), primary_id, - primary_params - ); - } - - // Compute log normalizer using upper_interval - if (D > upper_interval) { - if (is_inf(D)) { - log_normalizer = 0; // No normalization needed for infinite D - } else { - log_normalizer = primarycensored_lcdf( - D | dist_id, params, pwindow, positive_infinity(), - primary_id, primary_params - ); - } - } else { - log_normalizer = log_cdfs[upper_interval]; - } - - // Compute log PMFs - log_pmfs[1] = log_cdfs[1] - log_normalizer; - for (d in 2:upper_interval) { - log_pmfs[d] = log_diff_exp(log_cdfs[d], log_cdfs[d-1]) - log_normalizer; - } - - return log_pmfs; -} - -/** - * Compute the primary event censored PMF for integer delays up to max_delay - * - * @param max_delay Maximum delay to compute PMF for - * @param D Maximum delay (truncation point), must be at least max_delay + 1 - * @param dist_id Distribution identifier - * @param params Array of distribution parameters - * @param pwindow Primary event window - * @param primary_id Primary distribution identifier - * @param primary_params Primary distribution parameters - * - * @return Vector of primary event censored PMFs for integer delays 1 to - * max_delay - * - * This function differs from primarycensored_pmf in that it: - * 1. Computes PMFs for all integer delays from \[0, 1\] to \[max_delay, - * max_delay + 1\] in one call. - * 2. Assumes integer delays (swindow = 1) - * 3. Is more computationally efficient for multiple delay calculations - * - * @code - * // Example: Weibull delay distribution with uniform primary distribution - * int max_delay = 10; - * real D = 15.0; - * int dist_id = 5; // Weibull - * array[2] real params = {2.0, 1.5}; // shape and scale - * real pwindow = 7.0; - * int primary_id = 1; // Uniform - * array[0] real primary_params = {}; - * vector[max_delay] pmf = - * primarycensored_sone_lpmf_vectorized( - * max_delay, D, dist_id, params, pwindow, primary_id, primary_params - * ); - * @endcode - */ -vector primarycensored_sone_pmf_vectorized( - int max_delay, data real D, int dist_id, - array[] real params, data real pwindow, - int primary_id, - array[] real primary_params -) { - return exp( - primarycensored_sone_lpmf_vectorized( - max_delay, D, dist_id, params, pwindow, primary_id, primary_params - ) - ); -} - - -/** - * Check if an analytical solution exists for the given distribution combination - * - * @param dist_id Distribution identifier for the delay distribution - * @param primary_id Distribution identifier for the primary distribution - * - * @return 1 if an analytical solution exists, 0 otherwise - */ -int check_for_analytical(int dist_id, int primary_id) { - if (dist_id == 2 && primary_id == 1) return 1; // Gamma delay with Uniform primary - if (dist_id == 1 && primary_id == 1) return 1; // Lognormal delay with Uniform primary - if (dist_id == 3 && primary_id == 1) return 1; // Weibull delay with Uniform primary - return 0; // No analytical solution for other combinations -} - -/** - * Compute the primary event censored log CDF analytically for Gamma delay with Uniform primary - * - * @param d Delay time - * @param q Lower bound of integration (max(d - pwindow, 0)) - * @param params Array of Gamma distribution parameters [shape, rate] - * @param pwindow Primary event window - * - * @return Log of the primary event censored CDF for Gamma delay with Uniform - * primary - */ -real primarycensored_gamma_uniform_lcdf(data real d, real q, array[] real params, data real pwindow) { - real shape = params[1]; - real rate = params[2]; - real shape_1 = shape + 1; - real log_window = log(pwindow); - - real log_F_T = gamma_lcdf(d | shape, rate); - real log_F_T_kp1 = gamma_lcdf(d | shape_1, rate); - - real log_delta_F_T_kp1; - real log_delta_F_T_k; - real log_F_Splus; - - if (q != 0) { - real log_F_T_q = gamma_lcdf(q | shape, rate); - real log_F_T_q_kp1 = gamma_lcdf(q | shape_1, rate); - - // Ensure that the first argument is greater than the second - log_delta_F_T_kp1 = log_diff_exp(log_F_T_kp1, log_F_T_q_kp1); - log_delta_F_T_k = log_diff_exp(log_F_T, log_F_T_q); - - log_F_Splus = log_diff_exp( - log_F_T, - log_diff_exp( - log(shape * inv(rate)) + log_delta_F_T_kp1, - log(d - pwindow) + log_delta_F_T_k - ) - log_window - ); - } else { - log_delta_F_T_kp1 = log_F_T_kp1; - log_delta_F_T_k = log_F_T; - - log_F_Splus = log_diff_exp( - log_F_T, - log_sum_exp( - log(shape * inv(rate)) + log_delta_F_T_kp1, - log(pwindow - d) + log_delta_F_T_k - ) - log_window - ); - } - - return log_F_Splus; -} - -/** - * Compute the primary event censored log CDF analytically for Lognormal delay with Uniform primary - * - * @param d Delay time - * @param q Lower bound of integration (max(d - pwindow, 0)) - * @param params Array of Lognormal distribution parameters [mu, sigma] - * @param pwindow Primary event window - * - * @return Log of the primary event censored CDF for Lognormal delay with - * Uniform primary - */ -real primarycensored_lognormal_uniform_lcdf(data real d, real q, array[] real params, data real pwindow) { - real mu = params[1]; - real sigma = params[2]; - real mu_sigma2 = mu + square(sigma); - real log_window = log(pwindow); - - real log_F_T = lognormal_lcdf(d | mu, sigma); - real log_F_T_mu_sigma2 = lognormal_lcdf(d | mu_sigma2, sigma); - - real log_delta_F_T_mu_sigma; - real log_delta_F_T; - real log_F_Splus; - - if (q != 0) { - real log_F_T_q = lognormal_lcdf(q | mu, sigma); - real log_F_T_q_mu_sigma2 = lognormal_lcdf(q | mu_sigma2, sigma); - - // Ensure that the first argument is greater than the second - log_delta_F_T_mu_sigma = log_diff_exp( - log_F_T_mu_sigma2, log_F_T_q_mu_sigma2 - ); - log_delta_F_T = log_diff_exp(log_F_T, log_F_T_q); - - log_F_Splus = log_diff_exp( - log_F_T, - log_diff_exp( - (mu + 0.5 * square(sigma)) + log_delta_F_T_mu_sigma, - log(d - pwindow) + log_delta_F_T - ) - log_window - ); - } else { - log_delta_F_T_mu_sigma = log_F_T_mu_sigma2; - log_delta_F_T = log_F_T; - - log_F_Splus = log_diff_exp( - log_F_T, - log_sum_exp( - (mu + 0.5 * square(sigma)) + log_delta_F_T_mu_sigma, - log(pwindow - d) + log_delta_F_T - ) - log_window - ); - } - - return log_F_Splus; -} - -/** - * Compute the log of the lower incomplete gamma function - * - * This function is used in the analytical solution for the primary censored - * Weibull distribution with uniform primary censoring. It corresponds to the - * g(t; λ, k) function described in the analytic solutions document. - * - * @param t Upper bound of integration - * @param shape Shape parameter (k) of the Weibull distribution - * @param scale Scale parameter (λ) of the Weibull distribution - * - * @return Log of g(t; λ, k) = γ(1 + 1/k, (t/λ)^k) - */ -real log_weibull_g(real t, real shape, real scale) { - real x = pow(t * inv(scale), shape); - real a = 1 + inv(shape); - return log(gamma_p(a, x)) + lgamma(a); -} - -/** - * Compute the primary event censored log CDF analytically for Weibull delay with Uniform primary - * - * @param d Delay time - * @param q Lower bound of integration (max(d - pwindow, 0)) - * @param params Array of Weibull distribution parameters [shape, scale] - * @param pwindow Primary event window - * - * @return Log of the primary event censored CDF for Weibull delay with - * Uniform primary - */ -real primarycensored_weibull_uniform_lcdf(data real d, real q, array[] real params, data real pwindow) { - real shape = params[1]; - real scale = params[2]; - real log_window = log(pwindow); - - real log_F_T = weibull_lcdf(d | shape, scale); - - real log_delta_g; - real log_delta_F_T; - real log_F_Splus; - - if (q != 0) { - real log_F_T_q = weibull_lcdf(q | shape, scale); - - log_delta_g = log_diff_exp( - log_weibull_g(d, shape, scale), - log_weibull_g(q, shape, scale) - ); - log_delta_F_T = log_diff_exp(log_F_T, log_F_T_q); - - log_F_Splus = log_diff_exp( - log_F_T, - log_diff_exp( - log(scale) + log_delta_g, - log(d - pwindow) + log_delta_F_T - ) - log_window - ); - } else { - log_delta_g = log_weibull_g(d, shape, scale); - log_delta_F_T = log_F_T; - - log_F_Splus = log_diff_exp( - log_F_T, - log_sum_exp( - log(scale) + log_delta_g, - log(pwindow - d) + log_delta_F_T - ) - log_window - ); - } - - return log_F_Splus; -} - -/** - * Compute the primary event censored log CDF analytically for a single delay - * - * @param d Delay - * @param dist_id Distribution identifier - * @param params Array of distribution parameters - * @param pwindow Primary event window - * @param D Maximum delay (truncation point) - * @param primary_id Primary distribution identifier - * @param primary_params Primary distribution parameters - * - * @return Primary event censored log CDF, normalized by D if finite (truncation adjustment) - */ -real primarycensored_analytical_lcdf(data real d, int dist_id, - array[] real params, - data real pwindow, data real D, - int primary_id, - array[] real primary_params) { - real result; - real log_cdf_D; - - if (d <= 0) return negative_infinity(); - if (d >= D) return 0; - - real q = max({d - pwindow, 0}); - - if (dist_id == 2 && primary_id == 1) { - // Gamma delay with Uniform primary - result = primarycensored_gamma_uniform_lcdf(d | q, params, pwindow); - } else if (dist_id == 1 && primary_id == 1) { - // Lognormal delay with Uniform primary - result = primarycensored_lognormal_uniform_lcdf(d | q, params, pwindow); - } else if (dist_id == 3 && primary_id == 1) { - // Weibull delay with Uniform primary - result = primarycensored_weibull_uniform_lcdf(d | q, params, pwindow); - } else { - // No analytical solution available - return negative_infinity(); - } - - if (!is_inf(D)) { - log_cdf_D = primarycensored_lcdf( - D | dist_id, params, pwindow, positive_infinity(), - primary_id, primary_params - ); - result = result - log_cdf_D; - } - - return result; -} - -/** - * Compute the primary event censored CDF analytically for a single delay - * - * @param d Delay - * @param dist_id Distribution identifier - * @param params Array of distribution parameters - * @param pwindow Primary event window - * @param D Maximum delay (truncation point) - * @param primary_id Primary distribution identifier - * @param primary_params Primary distribution parameters - * - * @return Primary event censored CDF, normalized by D if finite (truncation adjustment) - */ -real primarycensored_analytical_cdf(data real d, int dist_id, - array[] real params, - data real pwindow, data real D, - int primary_id, - array[] real primary_params) { - return exp(primarycensored_analytical_lcdf(d | dist_id, params, pwindow, D, primary_id, primary_params)); -} - -/** - * Compute the log CDF of the delay distribution - * - * @param delay Time delay - * @param params Distribution parameters - * @param dist_id Distribution identifier - * 1: Lognormal, 2: Gamma, 3: Normal, 4: Exponential, 5: Weibull, - * 6: Beta, 7: Cauchy, 8: Chi-square, 9: Inverse Chi-square, - * 10: Double Exponential, 11: Inverse Gamma, 12: Logistic, - * 13: Pareto, 14: Scaled Inverse Chi-square, 15: Student's t, - * 16: Uniform, 17: von Mises - * - * @return Log CDF of the delay distribution - * - * @code - * // Example: Lognormal distribution - * real delay = 5.0; - * array[2] real params = {0.0, 1.0}; // mean and standard deviation on log scale - * int dist_id = 1; // Lognormal - * real log_cdf = dist_lcdf(delay, params, dist_id); - * @endcode - */ -real dist_lcdf(real delay, array[] real params, int dist_id) { - if (delay <= 0) return negative_infinity(); - - // Use if-else statements to handle different distribution types - if (dist_id == 1) return lognormal_lcdf(delay | params[1], params[2]); - else if (dist_id == 2) return gamma_lcdf(delay | params[1], params[2]); - else if (dist_id == 3) return normal_lcdf(delay | params[1], params[2]); - else if (dist_id == 4) return exponential_lcdf(delay | params[1]); - else if (dist_id == 5) return weibull_lcdf(delay | params[1], params[2]); - else if (dist_id == 6) return beta_lcdf(delay | params[1], params[2]); - else if (dist_id == 7) return cauchy_lcdf(delay | params[1], params[2]); - else if (dist_id == 8) return chi_square_lcdf(delay | params[1]); - else if (dist_id == 9) return inv_chi_square_lcdf(delay | params[1]); - else if (dist_id == 10) return double_exponential_lcdf(delay | params[1], params[2]); - else if (dist_id == 11) return inv_gamma_lcdf(delay | params[1], params[2]); - else if (dist_id == 12) return logistic_lcdf(delay | params[1], params[2]); - else if (dist_id == 13) return pareto_lcdf(delay | params[1], params[2]); - else if (dist_id == 14) return scaled_inv_chi_square_lcdf(delay | params[1], params[2]); - else if (dist_id == 15) return student_t_lcdf(delay | params[1], params[2], params[3]); - else if (dist_id == 16) return uniform_lcdf(delay | params[1], params[2]); - else if (dist_id == 17) return von_mises_lcdf(delay | params[1], params[2]); - else reject("Invalid distribution identifier"); -} - -/** - * Compute the log PDF of the primary distribution - * - * @param x Value - * @param primary_id Primary distribution identifier - * @param params Distribution parameters - * @param min Minimum value - * @param max Maximum value - * - * @return Log PDF of the primary distribution - * - * @code - * // Example: Uniform distribution - * real x = 0.5; - * int primary_id = 1; // Uniform - * array[0] real params = {}; // No additional parameters for uniform - * real min = 0; - * real max = 1; - * real log_pdf = primary_lpdf(x, primary_id, params, min, max); - * @endcode - */ -real primary_lpdf(real x, int primary_id, array[] real params, real min, real max) { - // Implement switch for different primary distributions - if (primary_id == 1) return uniform_lpdf(x | min, max); - if (primary_id == 2) return expgrowth_lpdf(x | min, max, params[1]); - // Add more primary distributions as needed - reject("Invalid primary distribution identifier"); -} - -/** - * ODE system for the primary censored distribution - * - * @param t Time - * @param y State variables - * @param theta Parameters - * @param x_r Real data - * @param x_i Integer data - * - * @return Derivatives of the state variables - */ -vector primarycensored_ode(real t, vector y, array[] real theta, - array[] real x_r, array[] int x_i) { - real d = x_r[1]; - int dist_id = x_i[1]; - int primary_id = x_i[2]; - real pwindow = x_r[2]; - int dist_params_len = x_i[3]; - int primary_params_len = x_i[4]; - - // Extract distribution parameters - array[dist_params_len] real params; - if (dist_params_len) { - params = theta[1:dist_params_len]; - } - array[primary_params_len] real primary_params; - if (primary_params_len) { - int primary_loc = size(theta); - primary_params = theta[primary_loc - primary_params_len + 1:primary_loc]; - } - - real log_cdf = dist_lcdf(t | params, dist_id); - real log_primary_pdf = primary_lpdf(d - t | primary_id, primary_params, 0, pwindow); - - return rep_vector(exp(log_cdf + log_primary_pdf), 1); -} - -/** - * Exponential growth probability density function (PDF) - * - * @param x Value at which to evaluate the PDF - * @param min Lower bound of the distribution - * @param max Upper bound of the distribution - * @param r Rate parameter for exponential growth - * @return The PDF evaluated at x - */ -real expgrowth_pdf(real x, real min, real max, real r) { - if (x < min || x > max) { - return 0; - } - if (abs(r) < 1e-10) { - return 1 / (max - min); - } - return r * exp(r * (x - min)) / (exp(r * max) - exp(r * min)); -} - -/** - * Exponential growth log probability density function (log PDF) - * - * @param x Value at which to evaluate the log PDF - * @param min Lower bound of the distribution - * @param max Upper bound of the distribution - * @param r Rate parameter for exponential growth - * @return The log PDF evaluated at x - */ -real expgrowth_lpdf(real x, real min, real max, real r) { - if (x < min || x > max) { - return negative_infinity(); - } - if (abs(r) < 1e-10) { - return -log(max - min); - } - return log(r) + r * (x - min) - log(exp(r * max) - exp(r * min)); -} - -/** - * Exponential growth cumulative distribution function (CDF) - * - * @param x Value at which to evaluate the CDF - * @param min Lower bound of the distribution - * @param max Upper bound of the distribution - * @param r Rate parameter for exponential growth - * @return The CDF evaluated at x - */ -real expgrowth_cdf(real x, real min, real max, real r) { - if (x < min) { - return 0; - } - if (x > max) { - return 1; - } - if (abs(r) < 1e-10) { - return (x - min) / (max - min); - } - return (exp(r * (x - min)) - exp(r * min)) / (exp(r * max) - exp(r * min)); -} - -/** - * Exponential growth log cumulative distribution function (log CDF) - * - * @param x Value at which to evaluate the log CDF - * @param min Lower bound of the distribution - * @param max Upper bound of the distribution - * @param r Rate parameter for exponential growth - * @return The log CDF evaluated at x - */ -real expgrowth_lcdf(real x, real min, real max, real r) { - if (x < min) { - return negative_infinity(); - } - if (x > max) { - return 0; - } - return log(expgrowth_cdf(x | min, max, r)); -} - -/** - * Exponential growth random number generator - * - * @param min Lower bound of the distribution - * @param max Upper bound of the distribution - * @param r Rate parameter for exponential growth - * @return A random draw from the exponential growth distribution - */ -real expgrowth_rng(real min, real max, real r) { - real u = uniform_rng(0, 1); - if (abs(r) < 1e-10) { - return min + u * (max - min); - } - return min + log(u * (exp(r * max) - exp(r * min)) + exp(r * min)) / r; -} From c56adf77614801034b977d345ca6d30e0e9f0c19 Mon Sep 17 00:00:00 2001 From: athowes Date: Fri, 15 Nov 2024 21:35:19 +0000 Subject: [PATCH 16/62] Just running into some C++ errors.. --- inst/marginal_model-scratch.R | 31 +++++++++++-------------- inst/stan/marginal_model/data.stan | 1 + inst/stan/marginal_model/functions.stan | 2 +- inst/stan/marginal_model/tdata.stan | 1 - 4 files changed, 15 insertions(+), 20 deletions(-) delete mode 100644 inst/stan/marginal_model/tdata.stan diff --git a/inst/marginal_model-scratch.R b/inst/marginal_model-scratch.R index 3753e10c1..b73355a64 100644 --- a/inst/marginal_model-scratch.R +++ b/inst/marginal_model-scratch.R @@ -55,15 +55,12 @@ primarycensored_family <- brms::custom_family( links = c(lognormal$link, lognormal$link_sigma), type = "int", loop = TRUE, - vars = c("vreal1[n]", "pwindow[n]") + vars = c("vreal1[n]") ) data <- cohort_obs |> select(d = delay, n = n) |> - mutate( - pwindow = 1, - q = pmax(d - pwindow, 0) - ) + mutate(pwindow = 1) pcd_stanvars_functions <- brms::stanvar( block = "functions", @@ -72,30 +69,28 @@ pcd_stanvars_functions <- brms::stanvar( stanvars_functions <- brms::stanvar( block = "functions", - scode = .stan_chunk("cohort_model/functions.stan") + scode = .stan_chunk(file.path("marginal_model", "functions.stan")) ) -pwindow <- data$pwindow +# pwindow <- data$pwindow +# +# stanvars_data <- brms::stanvar( +# x = pwindow, +# block = "data", +# scode = .stan_chunk("marginal_model/data.stan") +# ) -stanvars_data <- brms::stanvar( - x = pwindow, - block = "data", - scode = .stan_chunk("cohort_model/data.stan") -) - -stanvars_all <- pcd_stanvars_functions + stanvars_functions + stanvars_data +stanvars_all <- pcd_stanvars_functions + stanvars_functions stancode <- brms::make_stancode( - formula = d | weights(n) + vreal(q) ~ 1, + formula = d | weights(n) + vreal(pwindow) ~ 1, family = primarycensored_family, data = data, stanvars = stanvars_all, ) -model <- rstan::stan_model(model_code = stancode) - fit_pcd <- brms::brm( - formula = d | weights(n) + vreal(q) ~ 1, + formula = d | weights(n) + vreal(pwindow) ~ 1, family = primarycensored_family, data = data, stanvars = stanvars_all, diff --git a/inst/stan/marginal_model/data.stan b/inst/stan/marginal_model/data.stan index 00a3e0b49..c170a774e 100644 --- a/inst/stan/marginal_model/data.stan +++ b/inst/stan/marginal_model/data.stan @@ -1 +1,2 @@ vector[N] pwindow; +pwindow = vreal2; diff --git a/inst/stan/marginal_model/functions.stan b/inst/stan/marginal_model/functions.stan index 5df0ee1c6..d76eef67c 100644 --- a/inst/stan/marginal_model/functions.stan +++ b/inst/stan/marginal_model/functions.stan @@ -1,4 +1,4 @@ -real primarycensored_wrapper_lpmf(data int d, real mu, real sigma, real q, data real pwindow) { +real primarycensored_wrapper_lpmf(data int d, real mu, real sigma, data real pwindow) { int dist_id = 1; // lognormal array[2] real params = {mu, sigma}; int d_upper = d + 1; diff --git a/inst/stan/marginal_model/tdata.stan b/inst/stan/marginal_model/tdata.stan deleted file mode 100644 index 0e3a441b1..000000000 --- a/inst/stan/marginal_model/tdata.stan +++ /dev/null @@ -1 +0,0 @@ -vector[N] q = fmax(Y - pwindow, 0); From 419cb441dbb0e7f8c466559030d066ac3a1c39a9 Mon Sep 17 00:00:00 2001 From: athowes Date: Fri, 15 Nov 2024 21:42:13 +0000 Subject: [PATCH 17/62] This doesn't change anything --- inst/stan/marginal_model/functions.stan | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/inst/stan/marginal_model/functions.stan b/inst/stan/marginal_model/functions.stan index d76eef67c..fc7e9ee92 100644 --- a/inst/stan/marginal_model/functions.stan +++ b/inst/stan/marginal_model/functions.stan @@ -1,6 +1,8 @@ real primarycensored_wrapper_lpmf(data int d, real mu, real sigma, data real pwindow) { int dist_id = 1; // lognormal - array[2] real params = {mu, sigma}; + array[2] real params; + params[1] = mu; + params[2] = sigma; int d_upper = d + 1; int primary_id = 1; // Uniform array[0] real primary_params; From cff3176e2c8e88bbe9fa819c514714d769445037 Mon Sep 17 00:00:00 2001 From: athowes Date: Fri, 22 Nov 2024 14:16:43 +0000 Subject: [PATCH 18/62] Move to marginal model name and lint --- R/cohort_model.R | 50 ----------------------------------------- R/marginal_model.R | 55 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 55 insertions(+), 50 deletions(-) delete mode 100644 R/cohort_model.R create mode 100644 R/marginal_model.R diff --git a/R/cohort_model.R b/R/cohort_model.R deleted file mode 100644 index d1d25d91e..000000000 --- a/R/cohort_model.R +++ /dev/null @@ -1,50 +0,0 @@ -#' Prepare cohort model -#' -#' @param data A `data.frame` containing line list data -#' @family cohort_model -#' @export -as_cohort_model <- function(data) { - UseMethod("as_cohort_model") -} - -assert_cohort_model_input <- function(data) { - # ... -} - -#' Prepare cohort model -#' -#' @param data A `data.frame` containing line list data -#' @rdname as_cohort_model -#' @method as_cohort_model data.frame -#' @family cohort_model -#' @autoglobal -#' @export -as_cohort_model.data.frame <- function(data) { - assert_cohort_model_input(data) - class(data) <- c("epidist_cohort_model", class(data)) - data <- data |> - mutate(delay = .data$stime - .data$ptime) - epidist_validate(data) - return(data) -} - -#' Validate cohort model data -#' -#' @param data A `data.frame` containing line list data -#' @param ... ... -#' @method epidist_validate epidist_cohort_model -#' @family cohort_model -#' @export -epidist_validate.epidist_cohort_model <- function(data, ...) { - assert_true(is_cohort_model(data)) - assert_cohort_model_input(data) -} - -#' Check if data has the `epidist_cohort_model` class -#' -#' @param data A `data.frame` containing line list data -#' @family cohort_model -#' @export -is_cohort_model <- function(data) { - inherits(data, "epidist_cohort_model") -} diff --git a/R/marginal_model.R b/R/marginal_model.R new file mode 100644 index 000000000..cae63adf6 --- /dev/null +++ b/R/marginal_model.R @@ -0,0 +1,55 @@ +#' Prepare marginal model to pass through to `brms` +#' +#' @param data A `data.frame` containing line list data +#' @family marginal_model +#' @export +as_epidist_marginal_model <- function(data) { + UseMethod("as_epidist_marginal_model") +} + +#' The marginal model method for `epidist_linelist_data` objects +#' +#' @param data An `epidist_linelist_data` object +#' @method as_epidist_marginal_model epidist_linelist_data +#' @family marginal_model +#' @autoglobal +#' @export +as_epidist_marginal_model.epidist_linelist_data <- function(data) { + assert_epidist(data) + + data <- data |> + mutate(delay = .data$stime_lwr - .data$ptime_lwr) + + data <- new_epidist_marginal_model(data) + assert_epidist(data) + return(data) +} + +#' Class constructor for `epidist_marginal_model` objects +#' +#' @param data A data.frame to convert +#' @returns An object of class `epidist_marginal_model` +#' @family marginal_model +#' @export +new_epidist_marginal_model <- function(data) { + class(data) <- c("epidist_marginal_model", class(data)) + return(data) +} + +#' @method assert_epidist epidist_marginal_model +#' @family marginal_model +#' @export +assert_epidist.epidist_marginal_model <- function(data, ...) { + assert_data_frame(data) + assert_names(names(data), must.include = "delay") + assert_numeric(data$delay, lower = 0) +} + +#' Check if data has the `epidist_marginal_model` class +#' +#' @param data A `data.frame` containing line list data +#' @family marginal_model +#' @export +is_epidist_marginal_model <- function(data) { + inherits(data, "epidist_marginal_model") +} From 0a3e5ae2bfb48f200e5b9c997c2882aae1356b5b Mon Sep 17 00:00:00 2001 From: athowes Date: Fri, 22 Nov 2024 13:18:41 +0000 Subject: [PATCH 19/62] Create aggregate data inside model conversion function for now --- R/marginal_model.R | 8 +++++++- tests/testthat/test-marginal_model.R | 5 +++++ 2 files changed, 12 insertions(+), 1 deletion(-) create mode 100644 tests/testthat/test-marginal_model.R diff --git a/R/marginal_model.R b/R/marginal_model.R index cae63adf6..b54d4c581 100644 --- a/R/marginal_model.R +++ b/R/marginal_model.R @@ -17,8 +17,14 @@ as_epidist_marginal_model <- function(data) { as_epidist_marginal_model.epidist_linelist_data <- function(data) { assert_epidist(data) + # Here we do the processing to turn an epidist_linelist_data into an aggregate + # dataset. In the future this would be refactored into a function which + # converts from linelist data to aggregate data, and a function which goes + # from aggregate data into the marginal model class data <- data |> - mutate(delay = .data$stime_lwr - .data$ptime_lwr) + mutate(delay = .data$stime_lwr - .data$ptime_lwr) |> + dplyr::group_by(delay) |> + dplyr::summarise(count = dplyr::n()) data <- new_epidist_marginal_model(data) assert_epidist(data) diff --git a/tests/testthat/test-marginal_model.R b/tests/testthat/test-marginal_model.R new file mode 100644 index 000000000..42f8ef731 --- /dev/null +++ b/tests/testthat/test-marginal_model.R @@ -0,0 +1,5 @@ +test_that("as_epidist_marginal_model.epidist_linelist_data with default settings an object with the correct classes", { # nolint: line_length_linter. + prep_obs <- as_epidist_marginal_model(sim_obs) + expect_s3_class(prep_obs, "data.frame") + expect_s3_class(prep_obs, "epidist_latent_model") +}) From a9c5f64da62e2e6b0f691a526c5edd0bd5d16aa0 Mon Sep 17 00:00:00 2001 From: athowes Date: Fri, 22 Nov 2024 14:15:09 +0000 Subject: [PATCH 20/62] First draft on moving marginal_model into functions --- R/marginal_model.R | 63 ++++++++++++++++++++++++++++ tests/testthat/test-marginal_model.R | 42 ++++++++++++++++++- 2 files changed, 104 insertions(+), 1 deletion(-) diff --git a/R/marginal_model.R b/R/marginal_model.R index b54d4c581..a546c5eb4 100644 --- a/R/marginal_model.R +++ b/R/marginal_model.R @@ -59,3 +59,66 @@ assert_epidist.epidist_marginal_model <- function(data, ...) { is_epidist_marginal_model <- function(data) { inherits(data, "epidist_marginal_model") } + +#' Create the model-specific component of an `epidist` custom family +#' +#' @inheritParams epidist_family_model +#' @param ... Additional arguments passed to method. +#' @method epidist_family_model epidist_marginal_model +#' @family marginal_model +#' @export +epidist_family_model.epidist_marginal_model <- function( + data, family, ...) { + custom_family <- brms::custom_family( + "primarycensored_wrapper", + dpars = family$dpars, + links = c(family$link, family$other_links), + lb = c(NA, as.numeric(lapply(family$other_bounds, "[[", "lb"))), + ub = c(NA, as.numeric(lapply(family$other_bounds, "[[", "ub"))), + type = "int", + loop = TRUE, + vars = "vreal1[n]" + ) + return(custom_family) +} + +#' Define the model-specific component of an `epidist` custom formula +#' +#' @inheritParams epidist_formula_model +#' @param ... Additional arguments passed to method. +#' @method epidist_formula_model epidist_marginal_model +#' @family marginal_model +#' @export +epidist_formula_model.epidist_marginal_model <- function( + data, formula, ...) { + # data is only used to dispatch on + formula <- stats::update( + formula, delay | weights(n) + vreal(pwindow) ~ . + ) + return(formula) +} + +#' @method epidist_stancode epidist_marginal_model +#' @importFrom brms stanvar +#' @family marginal_model +#' @autoglobal +#' @export +epidist_stancode.epidist_marginal_model <- function(data, ...) { + assert_epidist(data) + + stanvars_version <- .version_stanvar() + + stanvars_functions <- brms::stanvar( + block = "functions", + scode = .stan_chunk(file.path("marginal_model", "functions.stan")) + ) + + pcd_stanvars_functions <- brms::stanvar( + block = "functions", + scode = pcd_load_stan_functions() + ) + + stanvars_all <- stanvars_version + stanvars_functions + pcd_stanvars_functions + + return(stanvars_all) +} diff --git a/tests/testthat/test-marginal_model.R b/tests/testthat/test-marginal_model.R index 42f8ef731..c02820cc2 100644 --- a/tests/testthat/test-marginal_model.R +++ b/tests/testthat/test-marginal_model.R @@ -1,5 +1,45 @@ test_that("as_epidist_marginal_model.epidist_linelist_data with default settings an object with the correct classes", { # nolint: line_length_linter. prep_obs <- as_epidist_marginal_model(sim_obs) expect_s3_class(prep_obs, "data.frame") - expect_s3_class(prep_obs, "epidist_latent_model") + expect_s3_class(prep_obs, "epidist_marginal_model") +}) + +test_that("as_epidist_marginal_model.epidist_linelist_data when passed incorrect inputs", { # nolint: line_length_linter. + expect_error(as_epidist_marginal_model(list())) + expect_error(as_epidist_marginal_model(sim_obs[, 1])) +}) + +# Make this data available for other tests +family_lognormal <- epidist_family(prep_obs, family = brms::lognormal()) + +test_that("is_epidist_marginal_model returns TRUE for correct input", { # nolint: line_length_linter. + expect_true(is_epidist_marginal_model(prep_obs)) + expect_true({ + x <- list() + class(x) <- "epidist_marginal_model" + is_epidist_marginal_model(x) + }) +}) + +test_that("is_epidist_marginal_model returns FALSE for incorrect input", { # nolint: line_length_linter. + expect_false(is_epidist_marginal_model(list())) + expect_false({ + x <- list() + class(x) <- "epidist_marginal_model_extension" + is_epidist_marginal_model(x) + }) +}) + +test_that("assert_epidist.epidist_marginal_model doesn't produce an error for correct input", { # nolint: line_length_linter. + expect_no_error(assert_epidist(prep_obs)) +}) + +test_that("assert_epidist.epidist_marginal_model returns FALSE for incorrect input", { # nolint: line_length_linter. + expect_error(assert_epidist(list())) + expect_error(assert_epidist(prep_obs[, 1])) + expect_error({ + x <- list() + class(x) <- "epidist_marginal_model" + assert_epidist(x) + }) }) From b560ecd56c654e6841ec48c29fb43f9962ab399b Mon Sep 17 00:00:00 2001 From: athowes Date: Fri, 22 Nov 2024 14:17:34 +0000 Subject: [PATCH 21/62] Run document --- NAMESPACE | 8 +++++ R/globals.R | 1 + man/as_epidist_marginal_model.Rd | 23 ++++++++++++++ ...st_marginal_model.epidist_linelist_data.Rd | 23 ++++++++++++++ ...ist_family_model.epidist_marginal_model.Rd | 28 +++++++++++++++++ ...st_formula_model.epidist_marginal_model.Rd | 31 +++++++++++++++++++ man/is_epidist_marginal_model.Rd | 23 ++++++++++++++ man/new_epidist_marginal_model.Rd | 26 ++++++++++++++++ 8 files changed, 163 insertions(+) create mode 100644 man/as_epidist_marginal_model.Rd create mode 100644 man/as_epidist_marginal_model.epidist_linelist_data.Rd create mode 100644 man/epidist_family_model.epidist_marginal_model.Rd create mode 100644 man/epidist_formula_model.epidist_marginal_model.Rd create mode 100644 man/is_epidist_marginal_model.Rd create mode 100644 man/new_epidist_marginal_model.Rd diff --git a/NAMESPACE b/NAMESPACE index 93ba3ff12..4c3bb2983 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -6,26 +6,32 @@ S3method(add_mean_sd,lognormal_samples) S3method(as_epidist_latent_model,epidist_linelist_data) S3method(as_epidist_linelist_data,data.frame) S3method(as_epidist_linelist_data,default) +S3method(as_epidist_marginal_model,epidist_linelist_data) S3method(as_epidist_naive_model,epidist_linelist_data) S3method(assert_epidist,default) S3method(assert_epidist,epidist_latent_model) S3method(assert_epidist,epidist_linelist_data) +S3method(assert_epidist,epidist_marginal_model) S3method(assert_epidist,epidist_naive_model) S3method(epidist_family_model,default) S3method(epidist_family_model,epidist_latent_model) S3method(epidist_family_param,default) +S3method(epidist_family_model,epidist_marginal_model) S3method(epidist_family_prior,default) S3method(epidist_family_prior,lognormal) S3method(epidist_formula_model,default) S3method(epidist_formula_model,epidist_latent_model) +S3method(epidist_formula_model,epidist_marginal_model) S3method(epidist_model_prior,default) S3method(epidist_model_prior,epidist_latent_model) S3method(epidist_stancode,default) S3method(epidist_stancode,epidist_latent_model) +S3method(epidist_stancode,epidist_marginal_model) export(Gamma) export(add_mean_sd) export(as_epidist_latent_model) export(as_epidist_linelist_data) +export(as_epidist_marginal_model) export(as_epidist_naive_model) export(assert_epidist) export(bf) @@ -44,10 +50,12 @@ export(epidist_prior) export(epidist_stancode) export(is_epidist_latent_model) export(is_epidist_linelist_data) +export(is_epidist_marginal_model) export(is_epidist_naive_model) export(lognormal) export(new_epidist_latent_model) export(new_epidist_linelist_data) +export(new_epidist_marginal_model) export(new_epidist_naive_model) export(predict_delay_parameters) export(predict_dpar) diff --git a/R/globals.R b/R/globals.R index 6c525d44b..d9fa26abc 100644 --- a/R/globals.R +++ b/R/globals.R @@ -3,6 +3,7 @@ utils::globalVariables(c( "samples", # "woverlap", # + "delay", # "rlnorm", # "fix", # <.replace_prior> "prior_new", # <.replace_prior> diff --git a/man/as_epidist_marginal_model.Rd b/man/as_epidist_marginal_model.Rd new file mode 100644 index 000000000..b0b846022 --- /dev/null +++ b/man/as_epidist_marginal_model.Rd @@ -0,0 +1,23 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/marginal_model.R +\name{as_epidist_marginal_model} +\alias{as_epidist_marginal_model} +\title{Prepare marginal model to pass through to \code{brms}} +\usage{ +as_epidist_marginal_model(data) +} +\arguments{ +\item{data}{A \code{data.frame} containing line list data} +} +\description{ +Prepare marginal model to pass through to \code{brms} +} +\seealso{ +Other marginal_model: +\code{\link{as_epidist_marginal_model.epidist_linelist_data}()}, +\code{\link{epidist_family_model.epidist_marginal_model}()}, +\code{\link{epidist_formula_model.epidist_marginal_model}()}, +\code{\link{is_epidist_marginal_model}()}, +\code{\link{new_epidist_marginal_model}()} +} +\concept{marginal_model} diff --git a/man/as_epidist_marginal_model.epidist_linelist_data.Rd b/man/as_epidist_marginal_model.epidist_linelist_data.Rd new file mode 100644 index 000000000..a77576f6d --- /dev/null +++ b/man/as_epidist_marginal_model.epidist_linelist_data.Rd @@ -0,0 +1,23 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/marginal_model.R +\name{as_epidist_marginal_model.epidist_linelist_data} +\alias{as_epidist_marginal_model.epidist_linelist_data} +\title{The marginal model method for \code{epidist_linelist_data} objects} +\usage{ +\method{as_epidist_marginal_model}{epidist_linelist_data}(data) +} +\arguments{ +\item{data}{An \code{epidist_linelist_data} object} +} +\description{ +The marginal model method for \code{epidist_linelist_data} objects +} +\seealso{ +Other marginal_model: +\code{\link{as_epidist_marginal_model}()}, +\code{\link{epidist_family_model.epidist_marginal_model}()}, +\code{\link{epidist_formula_model.epidist_marginal_model}()}, +\code{\link{is_epidist_marginal_model}()}, +\code{\link{new_epidist_marginal_model}()} +} +\concept{marginal_model} diff --git a/man/epidist_family_model.epidist_marginal_model.Rd b/man/epidist_family_model.epidist_marginal_model.Rd new file mode 100644 index 000000000..ea6746ee5 --- /dev/null +++ b/man/epidist_family_model.epidist_marginal_model.Rd @@ -0,0 +1,28 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/marginal_model.R +\name{epidist_family_model.epidist_marginal_model} +\alias{epidist_family_model.epidist_marginal_model} +\title{Create the model-specific component of an \code{epidist} custom family} +\usage{ +\method{epidist_family_model}{epidist_marginal_model}(data, family, ...) +} +\arguments{ +\item{data}{An object with class corresponding to an implemented model.} + +\item{family}{Output of a call to \code{brms::brmsfamily()} with additional +information as provided by \code{.add_dpar_info()}} + +\item{...}{Additional arguments passed to method.} +} +\description{ +Create the model-specific component of an \code{epidist} custom family +} +\seealso{ +Other marginal_model: +\code{\link{as_epidist_marginal_model}()}, +\code{\link{as_epidist_marginal_model.epidist_linelist_data}()}, +\code{\link{epidist_formula_model.epidist_marginal_model}()}, +\code{\link{is_epidist_marginal_model}()}, +\code{\link{new_epidist_marginal_model}()} +} +\concept{marginal_model} diff --git a/man/epidist_formula_model.epidist_marginal_model.Rd b/man/epidist_formula_model.epidist_marginal_model.Rd new file mode 100644 index 000000000..94806ff88 --- /dev/null +++ b/man/epidist_formula_model.epidist_marginal_model.Rd @@ -0,0 +1,31 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/marginal_model.R +\name{epidist_formula_model.epidist_marginal_model} +\alias{epidist_formula_model.epidist_marginal_model} +\title{Define the model-specific component of an \code{epidist} custom formula} +\usage{ +\method{epidist_formula_model}{epidist_marginal_model}(data, formula, ...) +} +\arguments{ +\item{data}{An object with class corresponding to an implemented model.} + +\item{formula}{An object of class \link[stats:formula]{stats::formula} or \link[brms:brmsformula]{brms::brmsformula} +(or one that can be coerced to those classes). A symbolic description of the +model to be fitted. A formula must be provided for the distributional +parameter \code{mu}, and may optionally be provided for other distributional +parameters.} + +\item{...}{Additional arguments passed to method.} +} +\description{ +Define the model-specific component of an \code{epidist} custom formula +} +\seealso{ +Other marginal_model: +\code{\link{as_epidist_marginal_model}()}, +\code{\link{as_epidist_marginal_model.epidist_linelist_data}()}, +\code{\link{epidist_family_model.epidist_marginal_model}()}, +\code{\link{is_epidist_marginal_model}()}, +\code{\link{new_epidist_marginal_model}()} +} +\concept{marginal_model} diff --git a/man/is_epidist_marginal_model.Rd b/man/is_epidist_marginal_model.Rd new file mode 100644 index 000000000..5585c2f78 --- /dev/null +++ b/man/is_epidist_marginal_model.Rd @@ -0,0 +1,23 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/marginal_model.R +\name{is_epidist_marginal_model} +\alias{is_epidist_marginal_model} +\title{Check if data has the \code{epidist_marginal_model} class} +\usage{ +is_epidist_marginal_model(data) +} +\arguments{ +\item{data}{A \code{data.frame} containing line list data} +} +\description{ +Check if data has the \code{epidist_marginal_model} class +} +\seealso{ +Other marginal_model: +\code{\link{as_epidist_marginal_model}()}, +\code{\link{as_epidist_marginal_model.epidist_linelist_data}()}, +\code{\link{epidist_family_model.epidist_marginal_model}()}, +\code{\link{epidist_formula_model.epidist_marginal_model}()}, +\code{\link{new_epidist_marginal_model}()} +} +\concept{marginal_model} diff --git a/man/new_epidist_marginal_model.Rd b/man/new_epidist_marginal_model.Rd new file mode 100644 index 000000000..f2abe7e32 --- /dev/null +++ b/man/new_epidist_marginal_model.Rd @@ -0,0 +1,26 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/marginal_model.R +\name{new_epidist_marginal_model} +\alias{new_epidist_marginal_model} +\title{Class constructor for \code{epidist_marginal_model} objects} +\usage{ +new_epidist_marginal_model(data) +} +\arguments{ +\item{data}{A data.frame to convert} +} +\value{ +An object of class \code{epidist_marginal_model} +} +\description{ +Class constructor for \code{epidist_marginal_model} objects +} +\seealso{ +Other marginal_model: +\code{\link{as_epidist_marginal_model}()}, +\code{\link{as_epidist_marginal_model.epidist_linelist_data}()}, +\code{\link{epidist_family_model.epidist_marginal_model}()}, +\code{\link{epidist_formula_model.epidist_marginal_model}()}, +\code{\link{is_epidist_marginal_model}()} +} +\concept{marginal_model} From 4d3990bc59b9b6028c550966052ded3f0c49d7b4 Mon Sep 17 00:00:00 2001 From: athowes Date: Fri, 22 Nov 2024 14:49:22 +0000 Subject: [PATCH 22/62] Tests working up to valid Stan code --- R/marginal_model.R | 18 ++++++++++--- tests/testthat/setup.R | 1 + tests/testthat/test-int-marginal_model.R | 34 ++++++++++++++++++++++++ 3 files changed, 49 insertions(+), 4 deletions(-) create mode 100644 tests/testthat/test-int-marginal_model.R diff --git a/R/marginal_model.R b/R/marginal_model.R index a546c5eb4..7e389c9ab 100644 --- a/R/marginal_model.R +++ b/R/marginal_model.R @@ -22,9 +22,19 @@ as_epidist_marginal_model.epidist_linelist_data <- function(data) { # converts from linelist data to aggregate data, and a function which goes # from aggregate data into the marginal model class data <- data |> - mutate(delay = .data$stime_lwr - .data$ptime_lwr) |> + mutate( + pwindow = ifelse( + .data$stime_lwr < .data$ptime_upr, + .data$stime_upr - .data$ptime_lwr, + .data$ptime_upr - .data$ptime_lwr + ), + delay = .data$stime_lwr - .data$ptime_lwr + ) |> dplyr::group_by(delay) |> - dplyr::summarise(count = dplyr::n()) + dplyr::summarise( + count = dplyr::n(), + pwindow = ifelse(all(pwindow == pwindow[1]), pwindow[1], NA) + ) data <- new_epidist_marginal_model(data) assert_epidist(data) @@ -93,7 +103,7 @@ epidist_formula_model.epidist_marginal_model <- function( data, formula, ...) { # data is only used to dispatch on formula <- stats::update( - formula, delay | weights(n) + vreal(pwindow) ~ . + formula, delay | weights(count) + vreal(pwindow) ~ . ) return(formula) } @@ -115,7 +125,7 @@ epidist_stancode.epidist_marginal_model <- function(data, ...) { pcd_stanvars_functions <- brms::stanvar( block = "functions", - scode = pcd_load_stan_functions() + scode = primarycensored::pcd_load_stan_functions() ) stanvars_all <- stanvars_version + stanvars_functions + pcd_stanvars_functions diff --git a/tests/testthat/setup.R b/tests/testthat/setup.R index 4cbb14949..03271e4cf 100644 --- a/tests/testthat/setup.R +++ b/tests/testthat/setup.R @@ -116,6 +116,7 @@ sim_obs_sex <- as_epidist_linelist_data( prep_obs <- as_epidist_latent_model(sim_obs) prep_naive_obs <- as_epidist_naive_model(sim_obs) +prep_marginal_obs <- as_epidist_marginal_model(sim_obs) prep_obs_gamma <- as_epidist_latent_model(sim_obs_gamma) prep_obs_sex <- as_epidist_latent_model(sim_obs_sex) diff --git a/tests/testthat/test-int-marginal_model.R b/tests/testthat/test-int-marginal_model.R new file mode 100644 index 000000000..97099b747 --- /dev/null +++ b/tests/testthat/test-int-marginal_model.R @@ -0,0 +1,34 @@ +# Note: some tests in this script are stochastic. As such, test failure may be +# bad luck rather than indicate an issue with the code. However, as these tests +# are reproducible, the distribution of test failures may be investigated by +# varying the input seed. Test failure at an unusually high rate does suggest +# a potential code issue. + +test_that("epidist.epidist_marginal_model Stan code has no syntax errors in the default case", { # nolint: line_length_linter. + skip_on_cran() + stancode <- epidist( + data = prep_marginal_obs, + fn = brms::make_stancode + ) + mod <- cmdstanr::cmdstan_model( + stan_file = cmdstanr::write_stan_file(stancode), compile = FALSE + ) + expect_true(mod$check_syntax()) +}) + +test_that("epidist.epidist_marginal_model fits and the MCMC converges in the default case", { # nolint: line_length_linter. + # Note: this test is stochastic. See note at the top of this script + skip_on_cran() + set.seed(1) + fit <- epidist( + data = prep_marginal_obs, + seed = 1, + silent = 2, refresh = 0, + cores = 2, + chains = 2, + backend = "cmdstanr" + ) + expect_s3_class(fit, "brmsfit") + expect_s3_class(fit, "epidist_fit") + expect_convergence(fit) +}) From e5c7f969e48f3cffd6981d27f0d66354b82a695d Mon Sep 17 00:00:00 2001 From: athowes Date: Mon, 25 Nov 2024 12:35:43 +0000 Subject: [PATCH 23/62] Regex version of marginal model --- R/marginal_model.R | 59 ++++++++++++++++++++++++- inst/stan/marginal_model/functions.stan | 18 +++++--- 2 files changed, 70 insertions(+), 7 deletions(-) diff --git a/R/marginal_model.R b/R/marginal_model.R index 7e389c9ab..4a80e6b52 100644 --- a/R/marginal_model.R +++ b/R/marginal_model.R @@ -80,7 +80,7 @@ is_epidist_marginal_model <- function(data) { epidist_family_model.epidist_marginal_model <- function( data, family, ...) { custom_family <- brms::custom_family( - "primarycensored_wrapper", + paste0("marginal_", family$family), dpars = family$dpars, links = c(family$link, family$other_links), lb = c(NA, as.numeric(lapply(family$other_bounds, "[[", "lb"))), @@ -113,7 +113,10 @@ epidist_formula_model.epidist_marginal_model <- function( #' @family marginal_model #' @autoglobal #' @export -epidist_stancode.epidist_marginal_model <- function(data, ...) { +epidist_stancode.epidist_marginal_model <- function( + data, + family = epidist_family(data), + formula = epidist_formula(data), ...) { assert_epidist(data) stanvars_version <- .version_stanvar() @@ -123,6 +126,58 @@ epidist_stancode.epidist_marginal_model <- function(data, ...) { scode = .stan_chunk(file.path("marginal_model", "functions.stan")) ) + family_name <- gsub("marginal_", "", family$name, fixed = TRUE) + + stanvars_functions[[1]]$scode <- gsub( + "family", family_name, stanvars_functions[[1]]$scode, + fixed = TRUE + ) + + # Can probably be extended to non-analytic solution families but for now + if (family_name == "lognormal") { + dist_id <- 1 + } else if (family_name == "gamma") { + dist_id <- 2 + } else if (family_name == "weibell") { + dist_id <- 3 + } else { + cli_abort(c( + "!" = "No analytic solution available in primarycensored for this family" + )) + } + + # Replace the dist_id passed to primarycensored + stanvars_functions[[1]]$scode <- gsub( + "input_dist_id", dist_id, stanvars_functions[[1]]$scode, + fixed = TRUE + ) + + # Inject vector or real depending if there is a model for each dpar + vector_real <- purrr::map_vec(family$dpars, function(dpar) { + return("real") + }) + + stanvars_functions[[1]]$scode <- gsub( + "dpars_A", + toString(paste0(vector_real, " ", family$dpars)), + stanvars_functions[[1]]$scode, + fixed = TRUE + ) + + # Need to consider whether any reparametrisation is required here for input + # input primarycensored. Assume not for now. Also assume two dpars + stanvars_functions[[1]]$scode <- gsub( + "dpars_1", family$dpars[1], + stanvars_functions[[1]]$scode, + fixed = TRUE + ) + + stanvars_functions[[1]]$scode <- gsub( + "dpars_2", family$dpars[2], + stanvars_functions[[1]]$scode, + fixed = TRUE + ) + pcd_stanvars_functions <- brms::stanvar( block = "functions", scode = primarycensored::pcd_load_stan_functions() diff --git a/inst/stan/marginal_model/functions.stan b/inst/stan/marginal_model/functions.stan index fc7e9ee92..5df25b5b9 100644 --- a/inst/stan/marginal_model/functions.stan +++ b/inst/stan/marginal_model/functions.stan @@ -1,10 +1,18 @@ -real primarycensored_wrapper_lpmf(data int d, real mu, real sigma, data real pwindow) { - int dist_id = 1; // lognormal +// This function is a wrapper to primarycensored_lpmf +// Here the strings +// * family +// * dpars_A +// * dpars_1 +// * dpars_2 +// are/have been replaced using regex + +real marginal_family_lpmf(data int d, dpars_A, data real pwindow) { + int dist_id = input_dist_id; array[2] real params; - params[1] = mu; - params[2] = sigma; + params[1] = dpars_1; + params[2] = dpars_2; int d_upper = d + 1; - int primary_id = 1; // Uniform + int primary_id = 1; // Fixed as uniform array[0] real primary_params; return primarycensored_lpmf(d | dist_id, params, pwindow, d_upper, positive_infinity(), primary_id, primary_params); } From 29b0e2ce22be5aa196c86f31382f9503c62c75ae Mon Sep 17 00:00:00 2001 From: athowes Date: Mon, 25 Nov 2024 13:52:53 +0000 Subject: [PATCH 24/62] Use prep_marginal_obs --- tests/testthat/test-marginal_model.R | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/tests/testthat/test-marginal_model.R b/tests/testthat/test-marginal_model.R index c02820cc2..34085284b 100644 --- a/tests/testthat/test-marginal_model.R +++ b/tests/testthat/test-marginal_model.R @@ -1,7 +1,7 @@ test_that("as_epidist_marginal_model.epidist_linelist_data with default settings an object with the correct classes", { # nolint: line_length_linter. - prep_obs <- as_epidist_marginal_model(sim_obs) - expect_s3_class(prep_obs, "data.frame") - expect_s3_class(prep_obs, "epidist_marginal_model") + prep_marginal_obs <- as_epidist_marginal_model(sim_obs) + expect_s3_class(prep_marginal_obs, "data.frame") + expect_s3_class(prep_marginal_obs, "epidist_marginal_model") }) test_that("as_epidist_marginal_model.epidist_linelist_data when passed incorrect inputs", { # nolint: line_length_linter. @@ -10,10 +10,13 @@ test_that("as_epidist_marginal_model.epidist_linelist_data when passed incorrect }) # Make this data available for other tests -family_lognormal <- epidist_family(prep_obs, family = brms::lognormal()) +family_lognormal <- epidist_family( + prep_marginal_obs, + family = brms::lognormal() +) test_that("is_epidist_marginal_model returns TRUE for correct input", { # nolint: line_length_linter. - expect_true(is_epidist_marginal_model(prep_obs)) + expect_true(is_epidist_marginal_model(prep_marginal_obs)) expect_true({ x <- list() class(x) <- "epidist_marginal_model" @@ -31,12 +34,12 @@ test_that("is_epidist_marginal_model returns FALSE for incorrect input", { # nol }) test_that("assert_epidist.epidist_marginal_model doesn't produce an error for correct input", { # nolint: line_length_linter. - expect_no_error(assert_epidist(prep_obs)) + expect_no_error(assert_epidist(prep_marginal_obs)) }) test_that("assert_epidist.epidist_marginal_model returns FALSE for incorrect input", { # nolint: line_length_linter. expect_error(assert_epidist(list())) - expect_error(assert_epidist(prep_obs[, 1])) + expect_error(assert_epidist(prep_marginal_obs[, 1])) expect_error({ x <- list() class(x) <- "epidist_marginal_model" From 08f477e05d541c5c598d3fc476b4ff350a978e95 Mon Sep 17 00:00:00 2001 From: athowes Date: Mon, 25 Nov 2024 14:39:43 +0000 Subject: [PATCH 25/62] Improve assert for marginal model --- R/marginal_model.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/marginal_model.R b/R/marginal_model.R index 4a80e6b52..df035af6b 100644 --- a/R/marginal_model.R +++ b/R/marginal_model.R @@ -57,7 +57,7 @@ new_epidist_marginal_model <- function(data) { #' @export assert_epidist.epidist_marginal_model <- function(data, ...) { assert_data_frame(data) - assert_names(names(data), must.include = "delay") + assert_names(names(data), must.include = c("delay", "count", "pwindow")) assert_numeric(data$delay, lower = 0) } From 944c455acc408ca71286e5c9c1c4864eb685a89a Mon Sep 17 00:00:00 2001 From: athowes Date: Mon, 25 Nov 2024 14:44:56 +0000 Subject: [PATCH 26/62] Add pkgdown and document --- R/globals.R | 1 + _pkgdown.yml | 4 ++++ 2 files changed, 5 insertions(+) diff --git a/R/globals.R b/R/globals.R index d9fa26abc..ef9eea43a 100644 --- a/R/globals.R +++ b/R/globals.R @@ -4,6 +4,7 @@ utils::globalVariables(c( "samples", # "woverlap", # "delay", # + "pwindow", # "rlnorm", # "fix", # <.replace_prior> "prior_new", # <.replace_prior> diff --git a/_pkgdown.yml b/_pkgdown.yml index 51ef0149f..83a47222d 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -45,6 +45,10 @@ reference: desc: Specific methods for the latent model contents: - has_concept("latent_model") +- title: Marginal model + desc: Specific methods for the marginal model + contents: + - has_concept("marginal_model") - title: Postprocess desc: Functions for postprocessing model output contents: From d545f82409a0d95e8b5efa091da75c246f8456e7 Mon Sep 17 00:00:00 2001 From: athowes Date: Wed, 27 Nov 2024 10:13:26 +0000 Subject: [PATCH 27/62] Clean up scratch implementation --- inst/marginal_model-scratch.R | 51 ++++++++++++++++++----- inst/stan/marginal_model/data.stan | 2 - inst/stan/marginal_model/tparameters.stan | 3 -- 3 files changed, 40 insertions(+), 16 deletions(-) delete mode 100644 inst/stan/marginal_model/data.stan delete mode 100644 inst/stan/marginal_model/tparameters.stan diff --git a/inst/marginal_model-scratch.R b/inst/marginal_model-scratch.R index b73355a64..0af43d512 100644 --- a/inst/marginal_model-scratch.R +++ b/inst/marginal_model-scratch.R @@ -18,21 +18,29 @@ sim_obs <- simulate_gillespie() |> meanlog = meanlog, sdlog = sdlog ) |> - observe_process() |> + mutate( + ptime_lwr = floor(.data$ptime), + ptime_upr = .data$ptime_lwr + 1, + stime_lwr = floor(.data$stime), + stime_upr = .data$stime_lwr + 1, + obs_time = obs_time, + delay = stime_lwr - ptime_lwr + ) |> + filter(.data$stime_upr <= .data$obs_time) |> filter(.data$stime_upr <= obs_time) |> dplyr::slice_sample(n = sample_size, replace = FALSE) # Create cohort version of data cohort_obs <- sim_obs |> - group_by(delay = delay_daily) |> + group_by(delay) |> summarise(n = n()) ggplot(cohort_obs, aes(x = delay, y = n)) + geom_col() fit_direct <- brms::brm( - formula = delay_daily ~ 1, + formula = delay ~ 1, family = "lognormal", data = sim_obs ) @@ -50,7 +58,7 @@ summary(fit_direct_weighted) lognormal <- brms::lognormal() primarycensored_family <- brms::custom_family( - "primarycensored_wrapper", + "marginal_lognormal", dpars = lognormal$dpar, links = c(lognormal$link, lognormal$link_sigma), type = "int", @@ -72,13 +80,34 @@ stanvars_functions <- brms::stanvar( scode = .stan_chunk(file.path("marginal_model", "functions.stan")) ) -# pwindow <- data$pwindow -# -# stanvars_data <- brms::stanvar( -# x = pwindow, -# block = "data", -# scode = .stan_chunk("marginal_model/data.stan") -# ) + +stanvars_functions[[1]]$scode <- gsub( + "family", "lognormal", stanvars_functions[[1]]$scode, + fixed = TRUE +) + +stanvars_functions[[1]]$scode <- gsub( + "input_dist_id", 1, stanvars_functions[[1]]$scode, + fixed = TRUE +) + +stanvars_functions[[1]]$scode <- gsub( + "dpars_A", "real mu, real sigma", + stanvars_functions[[1]]$scode, + fixed = TRUE +) + +stanvars_functions[[1]]$scode <- gsub( + "dpars_1", "mu", + stanvars_functions[[1]]$scode, + fixed = TRUE +) + +stanvars_functions[[1]]$scode <- gsub( + "dpars_2", "sigma", + stanvars_functions[[1]]$scode, + fixed = TRUE +) stanvars_all <- pcd_stanvars_functions + stanvars_functions diff --git a/inst/stan/marginal_model/data.stan b/inst/stan/marginal_model/data.stan deleted file mode 100644 index c170a774e..000000000 --- a/inst/stan/marginal_model/data.stan +++ /dev/null @@ -1,2 +0,0 @@ -vector[N] pwindow; -pwindow = vreal2; diff --git a/inst/stan/marginal_model/tparameters.stan b/inst/stan/marginal_model/tparameters.stan deleted file mode 100644 index f0f102c40..000000000 --- a/inst/stan/marginal_model/tparameters.stan +++ /dev/null @@ -1,3 +0,0 @@ -vector[2] params; -params[1] = mu; -params[2] = sigma; From 2ef2b592fdbc086ee56bb647eb76779188f8087e Mon Sep 17 00:00:00 2001 From: Sam Date: Fri, 29 Nov 2024 11:20:37 +0000 Subject: [PATCH 28/62] remove scratch file --- inst/marginal_model-scratch.R | 127 ---------------------------------- 1 file changed, 127 deletions(-) delete mode 100644 inst/marginal_model-scratch.R diff --git a/inst/marginal_model-scratch.R b/inst/marginal_model-scratch.R deleted file mode 100644 index 0af43d512..000000000 --- a/inst/marginal_model-scratch.R +++ /dev/null @@ -1,127 +0,0 @@ -library(dplyr) -library(ggplot2) -library(brms) -library(primarycensored) -library(bayesplot) - -set.seed(101) - -obs_time <- 25 -sample_size <- 500 - -meanlog <- 1.8 -sdlog <- 0.5 - -sim_obs <- simulate_gillespie() |> - simulate_secondary( - dist = rlnorm, - meanlog = meanlog, - sdlog = sdlog - ) |> - mutate( - ptime_lwr = floor(.data$ptime), - ptime_upr = .data$ptime_lwr + 1, - stime_lwr = floor(.data$stime), - stime_upr = .data$stime_lwr + 1, - obs_time = obs_time, - delay = stime_lwr - ptime_lwr - ) |> - filter(.data$stime_upr <= .data$obs_time) |> - filter(.data$stime_upr <= obs_time) |> - dplyr::slice_sample(n = sample_size, replace = FALSE) - -# Create cohort version of data - -cohort_obs <- sim_obs |> - group_by(delay) |> - summarise(n = n()) - -ggplot(cohort_obs, aes(x = delay, y = n)) + - geom_col() - -fit_direct <- brms::brm( - formula = delay ~ 1, - family = "lognormal", - data = sim_obs -) - -summary(fit_direct) - -fit_direct_weighted <- brms::brm( - formula = delay | weights(n) ~ 1, - family = "lognormal", - cohort_obs, -) - -summary(fit_direct_weighted) - -lognormal <- brms::lognormal() - -primarycensored_family <- brms::custom_family( - "marginal_lognormal", - dpars = lognormal$dpar, - links = c(lognormal$link, lognormal$link_sigma), - type = "int", - loop = TRUE, - vars = c("vreal1[n]") -) - -data <- cohort_obs |> - select(d = delay, n = n) |> - mutate(pwindow = 1) - -pcd_stanvars_functions <- brms::stanvar( - block = "functions", - scode = pcd_load_stan_functions() -) - -stanvars_functions <- brms::stanvar( - block = "functions", - scode = .stan_chunk(file.path("marginal_model", "functions.stan")) -) - - -stanvars_functions[[1]]$scode <- gsub( - "family", "lognormal", stanvars_functions[[1]]$scode, - fixed = TRUE -) - -stanvars_functions[[1]]$scode <- gsub( - "input_dist_id", 1, stanvars_functions[[1]]$scode, - fixed = TRUE -) - -stanvars_functions[[1]]$scode <- gsub( - "dpars_A", "real mu, real sigma", - stanvars_functions[[1]]$scode, - fixed = TRUE -) - -stanvars_functions[[1]]$scode <- gsub( - "dpars_1", "mu", - stanvars_functions[[1]]$scode, - fixed = TRUE -) - -stanvars_functions[[1]]$scode <- gsub( - "dpars_2", "sigma", - stanvars_functions[[1]]$scode, - fixed = TRUE -) - -stanvars_all <- pcd_stanvars_functions + stanvars_functions - -stancode <- brms::make_stancode( - formula = d | weights(n) + vreal(pwindow) ~ 1, - family = primarycensored_family, - data = data, - stanvars = stanvars_all, -) - -fit_pcd <- brms::brm( - formula = d | weights(n) + vreal(pwindow) ~ 1, - family = primarycensored_family, - data = data, - stanvars = stanvars_all, - backend = "cmdstanr" -) From cee62accb1799907dae289b87093209d128f364e Mon Sep 17 00:00:00 2001 From: Sam Date: Fri, 29 Nov 2024 15:33:07 +0000 Subject: [PATCH 29/62] update data format, formula, and family --- R/marginal_model.R | 48 ++++++++++++++----------- inst/stan/marginal_model/functions.stan | 43 +++++++++++++--------- 2 files changed, 55 insertions(+), 36 deletions(-) diff --git a/R/marginal_model.R b/R/marginal_model.R index df035af6b..86ac79e21 100644 --- a/R/marginal_model.R +++ b/R/marginal_model.R @@ -17,23 +17,14 @@ as_epidist_marginal_model <- function(data) { as_epidist_marginal_model.epidist_linelist_data <- function(data) { assert_epidist(data) - # Here we do the processing to turn an epidist_linelist_data into an aggregate - # dataset. In the future this would be refactored into a function which - # converts from linelist data to aggregate data, and a function which goes - # from aggregate data into the marginal model class data <- data |> mutate( - pwindow = ifelse( - .data$stime_lwr < .data$ptime_upr, - .data$stime_upr - .data$ptime_lwr, - .data$ptime_upr - .data$ptime_lwr - ), - delay = .data$stime_lwr - .data$ptime_lwr - ) |> - dplyr::group_by(delay) |> - dplyr::summarise( - count = dplyr::n(), - pwindow = ifelse(all(pwindow == pwindow[1]), pwindow[1], NA) + pwindow = .data$ptime_upr - .data$ptime_lwr, + swindow = .data$stime_upr - .data$stime_lwr, + relative_obs_time = .data$obs_time - .data$ptime_lwr, + delay_lwr = .data$stime_lwr - .data$ptime_lwr, + delay_upr = .data$stime_upr - .data$ptime_lwr, + n = 1 ) data <- new_epidist_marginal_model(data) @@ -57,8 +48,19 @@ new_epidist_marginal_model <- function(data) { #' @export assert_epidist.epidist_marginal_model <- function(data, ...) { assert_data_frame(data) - assert_names(names(data), must.include = c("delay", "count", "pwindow")) - assert_numeric(data$delay, lower = 0) + assert_names(names(data), must.include = c( + "ptime_lwr", "ptime_upr", "stime_lwr", "stime_upr", "obs_time", + "pwindow", "swindow", "delay_lwr", "delay_upr", "n" + )) + assert_numeric(data$pwindow, lower = 0) + assert_numeric(data$swindow, lower = 0) + assert_numeric(data$delay_lwr) + assert_numeric(data$delay_upr) + assert_true( + all(abs(data$delay_upr - (data$delay_lwr + data$swindow)) < 1e-10), + "delay_upr must equal delay_lwr + swindow" + ) + assert_numeric(data$n, lower = 0) } #' Check if data has the `epidist_marginal_model` class @@ -85,9 +87,14 @@ epidist_family_model.epidist_marginal_model <- function( links = c(family$link, family$other_links), lb = c(NA, as.numeric(lapply(family$other_bounds, "[[", "lb"))), ub = c(NA, as.numeric(lapply(family$other_bounds, "[[", "ub"))), - type = "int", + type = family$type, + vars = c( + "vreal1[n]", "vreal2[n]", "vreal3[n]", "vreal4[n]" + ), loop = TRUE, - vars = "vreal1[n]" + log_lik = epidist_gen_log_lik(family), + posterior_predict = epidist_gen_posterior_predict(family), + posterior_epred = epidist_gen_posterior_epred(family) ) return(custom_family) } @@ -103,7 +110,8 @@ epidist_formula_model.epidist_marginal_model <- function( data, formula, ...) { # data is only used to dispatch on formula <- stats::update( - formula, delay | weights(count) + vreal(pwindow) ~ . + formula, delay_lwr | weights(count) + + vreal(delay_upr, relative_obs_time, pwindow, swindow) ~ . ) return(formula) } diff --git a/inst/stan/marginal_model/functions.stan b/inst/stan/marginal_model/functions.stan index 5df25b5b9..801606519 100644 --- a/inst/stan/marginal_model/functions.stan +++ b/inst/stan/marginal_model/functions.stan @@ -1,18 +1,29 @@ -// This function is a wrapper to primarycensored_lpmf -// Here the strings -// * family -// * dpars_A -// * dpars_1 -// * dpars_2 -// are/have been replaced using regex +/** + * Compute the log probability density function for a marginal model with censoring + * + * This function is designed to be read into R where: + * - 'family' is replaced with the target distribution (e.g., 'lognormal') + * - 'dpars_A' is replaced with multiple parameters in the format + * "vector|real paramname1, vector|real paramname2, ..." depending on whether + * each parameter has a model. This includes distribution parameters. + * - 'dpars_B' is replaced with the same parameters as dpars_A but with window + * indices removed. + * + * @param y Real value of observed delay + * @param dpars_A Distribution parameters (replaced via regex) + * @param y_upper Upper bound of delay interval + * @param relative_obs_t Observation time relative to primary window start + * @param pwindow_width Primary window width (actual time scale) + * @param swindow_width Secondary window width (actual time scale) + * + * @return Log probability density with censoring adjustment for marginal model + */ + real marginal_family_lpdf(data real y, dpars_A, data real y_upper, + data real relative_obs_t, data real pwindow_width, + data real swindow_width) { -real marginal_family_lpmf(data int d, dpars_A, data real pwindow) { - int dist_id = input_dist_id; - array[2] real params; - params[1] = dpars_1; - params[2] = dpars_2; - int d_upper = d + 1; - int primary_id = 1; // Fixed as uniform - array[0] real primary_params; - return primarycensored_lpmf(d | dist_id, params, pwindow, d_upper, positive_infinity(), primary_id, primary_params); + return primarycensored_lpmf( + y | dist_id, {dpars_B}, pwindow, y_upper, relative_obs_t, + primary_id, {primary_params} + ); } From cb321952fc8067617e478fdef1433efcb7f56113 Mon Sep 17 00:00:00 2001 From: Sam Date: Fri, 29 Nov 2024 15:37:58 +0000 Subject: [PATCH 30/62] update stan code --- R/marginal_model.R | 33 +++++++++++++++++---------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/R/marginal_model.R b/R/marginal_model.R index 86ac79e21..effc446f4 100644 --- a/R/marginal_model.R +++ b/R/marginal_model.R @@ -49,18 +49,23 @@ new_epidist_marginal_model <- function(data) { assert_epidist.epidist_marginal_model <- function(data, ...) { assert_data_frame(data) assert_names(names(data), must.include = c( - "ptime_lwr", "ptime_upr", "stime_lwr", "stime_upr", "obs_time", - "pwindow", "swindow", "delay_lwr", "delay_upr", "n" + "pwindow", "swindow", "delay_lwr", "delay_upr", "n", + "relative_obs_time" )) assert_numeric(data$pwindow, lower = 0) assert_numeric(data$swindow, lower = 0) assert_numeric(data$delay_lwr) assert_numeric(data$delay_upr) + assert_numeric(data$relative_obs_time) assert_true( all(abs(data$delay_upr - (data$delay_lwr + data$swindow)) < 1e-10), "delay_upr must equal delay_lwr + swindow" ) - assert_numeric(data$n, lower = 0) + assert_true( + all(data$relative_obs_time >= data$delay_upr), + "relative_obs_time must be greater than or equal to delay_upr" + ) + assert_numeric(data$n, lower = 1) } #' Check if data has the `epidist_marginal_model` class @@ -110,7 +115,7 @@ epidist_formula_model.epidist_marginal_model <- function( data, formula, ...) { # data is only used to dispatch on formula <- stats::update( - formula, delay_lwr | weights(count) + + formula, delay_lwr | weights(n) + vreal(delay_upr, relative_obs_time, pwindow, swindow) ~ . ) return(formula) @@ -156,15 +161,10 @@ epidist_stancode.epidist_marginal_model <- function( # Replace the dist_id passed to primarycensored stanvars_functions[[1]]$scode <- gsub( - "input_dist_id", dist_id, stanvars_functions[[1]]$scode, + "dist_id", dist_id, stanvars_functions[[1]]$scode, fixed = TRUE ) - # Inject vector or real depending if there is a model for each dpar - vector_real <- purrr::map_vec(family$dpars, function(dpar) { - return("real") - }) - stanvars_functions[[1]]$scode <- gsub( "dpars_A", toString(paste0(vector_real, " ", family$dpars)), @@ -172,17 +172,18 @@ epidist_stancode.epidist_marginal_model <- function( fixed = TRUE ) - # Need to consider whether any reparametrisation is required here for input - # input primarycensored. Assume not for now. Also assume two dpars stanvars_functions[[1]]$scode <- gsub( - "dpars_1", family$dpars[1], - stanvars_functions[[1]]$scode, + "dpars_B", family$param, stanvars_functions[[1]]$scode, fixed = TRUE ) stanvars_functions[[1]]$scode <- gsub( - "dpars_2", family$dpars[2], - stanvars_functions[[1]]$scode, + "primary_id", "1", stanvars_functions[[1]]$scode, + fixed = TRUE + ) + + stanvars_functions[[1]]$scode <- gsub( + "primary_params", "", stanvars_functions[[1]]$scode, fixed = TRUE ) From 67cd715463f7a7d2a15e8f09c11880c7970ab42d Mon Sep 17 00:00:00 2001 From: Sam Date: Fri, 29 Nov 2024 16:47:40 +0000 Subject: [PATCH 31/62] basic working version --- NAMESPACE | 2 +- R/globals.R | 2 -- R/marginal_model.R | 40 +++++++++++++------------ inst/stan/marginal_model/functions.stan | 14 +++++---- vignettes/epidist.Rmd | 2 +- 5 files changed, 31 insertions(+), 29 deletions(-) diff --git a/NAMESPACE b/NAMESPACE index 4c3bb2983..beaa4ea64 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -15,8 +15,8 @@ S3method(assert_epidist,epidist_marginal_model) S3method(assert_epidist,epidist_naive_model) S3method(epidist_family_model,default) S3method(epidist_family_model,epidist_latent_model) -S3method(epidist_family_param,default) S3method(epidist_family_model,epidist_marginal_model) +S3method(epidist_family_param,default) S3method(epidist_family_prior,default) S3method(epidist_family_prior,lognormal) S3method(epidist_formula_model,default) diff --git a/R/globals.R b/R/globals.R index ef9eea43a..6c525d44b 100644 --- a/R/globals.R +++ b/R/globals.R @@ -3,8 +3,6 @@ utils::globalVariables(c( "samples", # "woverlap", # - "delay", # - "pwindow", # "rlnorm", # "fix", # <.replace_prior> "prior_new", # <.replace_prior> diff --git a/R/marginal_model.R b/R/marginal_model.R index effc446f4..6bfe5843b 100644 --- a/R/marginal_model.R +++ b/R/marginal_model.R @@ -54,17 +54,19 @@ assert_epidist.epidist_marginal_model <- function(data, ...) { )) assert_numeric(data$pwindow, lower = 0) assert_numeric(data$swindow, lower = 0) - assert_numeric(data$delay_lwr) - assert_numeric(data$delay_upr) + assert_integerish(data$delay_lwr) + assert_integerish(data$delay_upr) assert_numeric(data$relative_obs_time) - assert_true( - all(abs(data$delay_upr - (data$delay_lwr + data$swindow)) < 1e-10), - "delay_upr must equal delay_lwr + swindow" - ) - assert_true( - all(data$relative_obs_time >= data$delay_upr), - "relative_obs_time must be greater than or equal to delay_upr" - ) + if (!all(abs(data$delay_upr - (data$delay_lwr + data$swindow)) < 1e-10)) { + cli::cli_abort( + "delay_upr must equal delay_lwr + swindow" + ) + } + if (!all(data$relative_obs_time >= data$delay_upr)) { + cli::cli_abort( + "relative_obs_time must be greater than or equal to delay_upr" + ) + } assert_numeric(data$n, lower = 1) } @@ -92,9 +94,9 @@ epidist_family_model.epidist_marginal_model <- function( links = c(family$link, family$other_links), lb = c(NA, as.numeric(lapply(family$other_bounds, "[[", "lb"))), ub = c(NA, as.numeric(lapply(family$other_bounds, "[[", "ub"))), - type = family$type, + type = "int", vars = c( - "vreal1[n]", "vreal2[n]", "vreal3[n]", "vreal4[n]" + "vreal1[n]", "vreal2[n]", "vreal3[n]", "vreal4[n]", "primary_params" ), loop = TRUE, log_lik = epidist_gen_log_lik(family), @@ -146,7 +148,6 @@ epidist_stancode.epidist_marginal_model <- function( fixed = TRUE ) - # Can probably be extended to non-analytic solution families but for now if (family_name == "lognormal") { dist_id <- 1 } else if (family_name == "gamma") { @@ -155,7 +156,7 @@ epidist_stancode.epidist_marginal_model <- function( dist_id <- 3 } else { cli_abort(c( - "!" = "No analytic solution available in primarycensored for this family" + "!" = "No solution available in primarycensored for this family" )) } @@ -167,7 +168,7 @@ epidist_stancode.epidist_marginal_model <- function( stanvars_functions[[1]]$scode <- gsub( "dpars_A", - toString(paste0(vector_real, " ", family$dpars)), + toString(paste0("real ", family$dpars)), stanvars_functions[[1]]$scode, fixed = TRUE ) @@ -182,9 +183,9 @@ epidist_stancode.epidist_marginal_model <- function( fixed = TRUE ) - stanvars_functions[[1]]$scode <- gsub( - "primary_params", "", stanvars_functions[[1]]$scode, - fixed = TRUE + stanvars_parameters <- brms::stanvar( + block = "parameters", + scode = "array[0] real primary_params;" ) pcd_stanvars_functions <- brms::stanvar( @@ -192,7 +193,8 @@ epidist_stancode.epidist_marginal_model <- function( scode = primarycensored::pcd_load_stan_functions() ) - stanvars_all <- stanvars_version + stanvars_functions + pcd_stanvars_functions + stanvars_all <- stanvars_version + stanvars_functions + + pcd_stanvars_functions + stanvars_parameters return(stanvars_all) } diff --git a/inst/stan/marginal_model/functions.stan b/inst/stan/marginal_model/functions.stan index 801606519..1985413c5 100644 --- a/inst/stan/marginal_model/functions.stan +++ b/inst/stan/marginal_model/functions.stan @@ -1,5 +1,5 @@ /** - * Compute the log probability density function for a marginal model with censoring + * Compute the log probability mass function for a marginal model with censoring * * This function is designed to be read into R where: * - 'family' is replaced with the target distribution (e.g., 'lognormal') @@ -15,15 +15,17 @@ * @param relative_obs_t Observation time relative to primary window start * @param pwindow_width Primary window width (actual time scale) * @param swindow_width Secondary window width (actual time scale) + * @param primary_params Array of parameters for primary distribution * - * @return Log probability density with censoring adjustment for marginal model + * @return Log probability mass with censoring adjustment for marginal model */ - real marginal_family_lpdf(data real y, dpars_A, data real y_upper, + real marginal_family_lpmf(data int y, dpars_A, data real y_upper, data real relative_obs_t, data real pwindow_width, - data real swindow_width) { + data real swindow_width, + array[] real primary_params) { return primarycensored_lpmf( - y | dist_id, {dpars_B}, pwindow, y_upper, relative_obs_t, - primary_id, {primary_params} + y | dist_id, {dpars_B}, pwindow_width, y_upper, relative_obs_t, + primary_id, primary_params ); } diff --git a/vignettes/epidist.Rmd b/vignettes/epidist.Rmd index e7c8e807c..4dee7b343 100644 --- a/vignettes/epidist.Rmd +++ b/vignettes/epidist.Rmd @@ -274,7 +274,7 @@ linelist_data <- as_epidist_linelist_data( obs_time = obs_cens_trunc_samp$obs_time ) -data <- as_epidist_latent_model(linelist_data) +data <- as_epidist_marginal_model(linelist_data) class(data) ``` From 371ed9c42954b7ee9adcee7d4c23c28a144a3194 Mon Sep 17 00:00:00 2001 From: Sam Date: Fri, 29 Nov 2024 17:00:21 +0000 Subject: [PATCH 32/62] add transformm data methods --- NAMESPACE | 3 ++ R/epidist.R | 8 +++-- R/transform_data.R | 37 +++++++++++++++++++++ _pkgdown.yml | 4 +++ man/epidist_transform_data.Rd | 30 +++++++++++++++++ man/epidist_transform_data_model.Rd | 27 +++++++++++++++ man/epidist_transform_data_model.default.Rd | 27 +++++++++++++++ 7 files changed, 133 insertions(+), 3 deletions(-) create mode 100644 R/transform_data.R create mode 100644 man/epidist_transform_data.Rd create mode 100644 man/epidist_transform_data_model.Rd create mode 100644 man/epidist_transform_data_model.default.Rd diff --git a/NAMESPACE b/NAMESPACE index beaa4ea64..9d2ab4934 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -27,6 +27,7 @@ S3method(epidist_model_prior,epidist_latent_model) S3method(epidist_stancode,default) S3method(epidist_stancode,epidist_latent_model) S3method(epidist_stancode,epidist_marginal_model) +S3method(epidist_transform_data_model,default) export(Gamma) export(add_mean_sd) export(as_epidist_latent_model) @@ -48,6 +49,8 @@ export(epidist_gen_posterior_predict) export(epidist_model_prior) export(epidist_prior) export(epidist_stancode) +export(epidist_transform_data) +export(epidist_transform_data_model) export(is_epidist_latent_model) export(is_epidist_linelist_data) export(is_epidist_marginal_model) diff --git a/R/epidist.R b/R/epidist.R index 78db416a8..8825edea9 100644 --- a/R/epidist.R +++ b/R/epidist.R @@ -38,16 +38,18 @@ epidist <- function(data, formula = mu ~ 1, epidist_formula <- epidist_formula( data = data, family = epidist_family, formula = formula ) + trans_data <- epidist_transform_data(data, epidist_family, epidist_formula) epidist_prior <- epidist_prior( - data = data, family = epidist_family, formula = epidist_formula, prior, + data = trans_data, family = epidist_family, + formula = epidist_formula, prior, merge = merge_priors ) epidist_stancode <- epidist_stancode( - data = data, family = epidist_family, formula = epidist_formula + data = trans_data, family = epidist_family, formula = epidist_formula ) fit <- fn( formula = epidist_formula, family = epidist_family, prior = epidist_prior, - stanvars = epidist_stancode, data = data, ... + stanvars = epidist_stancode, data = trans_data, ... ) class(fit) <- c(class(fit), "epidist_fit") return(fit) diff --git a/R/transform_data.R b/R/transform_data.R new file mode 100644 index 000000000..a1acf50bc --- /dev/null +++ b/R/transform_data.R @@ -0,0 +1,37 @@ +#' Transform data for an epidist model +#' +#' This function is used within [epidist()] to transform data before passing to +#' `brms`. It is unlikely that as a user you will need this function, but we +#' export it nonetheless to be transparent about what happens inside of a call +#' to [epidist()]. +#' +#' @inheritParams epidist +#' @param family A description of the response distribution and link function to +#' be used in the model created using [epidist_family()]. +#' @param formula A formula object created using [epidist_formula()]. +#' @family transform +#' @export +epidist_transform_data <- function(data, family, formula, ...) { + assert_epidist(data) + data <- epidist_transform_data_model(data, family, formula) + return(data) +} + +#' The model-specific parts of an `epidist_transform_data()` call +#' +#' @inheritParams epidist_transform_data +#' @rdname epidist_transform_data_model +#' @family transform +#' @export +epidist_transform_data_model <- function(data, family, formula, ...) { + UseMethod("epidist_transform_data_model") +} + +#' Default method for transforming data for a model +#' +#' @inheritParams epidist_transform_data_model +#' @family transform +#' @export +epidist_transform_data_model.default <- function(data, family, formula, ...) { + return(data) +} diff --git a/_pkgdown.yml b/_pkgdown.yml index 83a47222d..5a1cff9c7 100644 --- a/_pkgdown.yml +++ b/_pkgdown.yml @@ -68,6 +68,10 @@ reference: desc: Functions related to specifying custom `brms` formula contents: - has_concept("formula") +- title: Transform data + desc: Transform data using the formula and family information + contents: + - has_concept("transform_data") - title: Prior distributions desc: Functions for specifying prior distributions contents: diff --git a/man/epidist_transform_data.Rd b/man/epidist_transform_data.Rd new file mode 100644 index 000000000..54b9a3a6a --- /dev/null +++ b/man/epidist_transform_data.Rd @@ -0,0 +1,30 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/transform_data.R +\name{epidist_transform_data} +\alias{epidist_transform_data} +\title{Transform data for an epidist model} +\usage{ +epidist_transform_data(data, family, formula, ...) +} +\arguments{ +\item{data}{An object with class corresponding to an implemented model.} + +\item{family}{A description of the response distribution and link function to +be used in the model created using \code{\link[=epidist_family]{epidist_family()}}.} + +\item{formula}{A formula object created using \code{\link[=epidist_formula]{epidist_formula()}}.} + +\item{...}{Additional arguments passed to \code{fn} method.} +} +\description{ +This function is used within \code{\link[=epidist]{epidist()}} to transform data before passing to +\code{brms}. It is unlikely that as a user you will need this function, but we +export it nonetheless to be transparent about what happens inside of a call +to \code{\link[=epidist]{epidist()}}. +} +\seealso{ +Other transform: +\code{\link{epidist_transform_data_model}()}, +\code{\link{epidist_transform_data_model.default}()} +} +\concept{transform} diff --git a/man/epidist_transform_data_model.Rd b/man/epidist_transform_data_model.Rd new file mode 100644 index 000000000..e7b0e8302 --- /dev/null +++ b/man/epidist_transform_data_model.Rd @@ -0,0 +1,27 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/transform_data.R +\name{epidist_transform_data_model} +\alias{epidist_transform_data_model} +\title{The model-specific parts of an \code{epidist_transform_data()} call} +\usage{ +epidist_transform_data_model(data, family, formula, ...) +} +\arguments{ +\item{data}{An object with class corresponding to an implemented model.} + +\item{family}{A description of the response distribution and link function to +be used in the model created using \code{\link[=epidist_family]{epidist_family()}}.} + +\item{formula}{A formula object created using \code{\link[=epidist_formula]{epidist_formula()}}.} + +\item{...}{Additional arguments passed to \code{fn} method.} +} +\description{ +The model-specific parts of an \code{epidist_transform_data()} call +} +\seealso{ +Other transform: +\code{\link{epidist_transform_data}()}, +\code{\link{epidist_transform_data_model.default}()} +} +\concept{transform} diff --git a/man/epidist_transform_data_model.default.Rd b/man/epidist_transform_data_model.default.Rd new file mode 100644 index 000000000..06b95aa8e --- /dev/null +++ b/man/epidist_transform_data_model.default.Rd @@ -0,0 +1,27 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/transform_data.R +\name{epidist_transform_data_model.default} +\alias{epidist_transform_data_model.default} +\title{Default method for transforming data for a model} +\usage{ +\method{epidist_transform_data_model}{default}(data, family, formula, ...) +} +\arguments{ +\item{data}{An object with class corresponding to an implemented model.} + +\item{family}{A description of the response distribution and link function to +be used in the model created using \code{\link[=epidist_family]{epidist_family()}}.} + +\item{formula}{A formula object created using \code{\link[=epidist_formula]{epidist_formula()}}.} + +\item{...}{Additional arguments passed to \code{fn} method.} +} +\description{ +Default method for transforming data for a model +} +\seealso{ +Other transform: +\code{\link{epidist_transform_data}()}, +\code{\link{epidist_transform_data_model}()} +} +\concept{transform} From eb37884aed6a2a5ef5616bb018078e3260734db2 Mon Sep 17 00:00:00 2001 From: Sam Date: Fri, 29 Nov 2024 18:21:51 +0000 Subject: [PATCH 33/62] start exploring the ebola example --- NAMESPACE | 5 ++++ R/marginal_model.R | 15 ++++++++++++ R/utils.R | 40 +++++++++++++++++++++++++++++++ man/dot-extract_dpar_terms.Rd | 19 +++++++++++++++ man/dot-summarise_n_by_formula.Rd | 24 +++++++++++++++++++ vignettes/ebola.Rmd | 2 +- 6 files changed, 104 insertions(+), 1 deletion(-) create mode 100644 man/dot-extract_dpar_terms.Rd create mode 100644 man/dot-summarise_n_by_formula.Rd diff --git a/NAMESPACE b/NAMESPACE index 9d2ab4934..8faf96881 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -28,6 +28,7 @@ S3method(epidist_stancode,default) S3method(epidist_stancode,epidist_latent_model) S3method(epidist_stancode,epidist_marginal_model) S3method(epidist_transform_data_model,default) +S3method(epidist_transform_data_model,epidist_marginal_model) export(Gamma) export(add_mean_sd) export(as_epidist_latent_model) @@ -89,14 +90,18 @@ importFrom(cli,cli_abort) importFrom(cli,cli_alert_info) importFrom(cli,cli_inform) importFrom(cli,cli_warn) +importFrom(dplyr,across) importFrom(dplyr,bind_cols) importFrom(dplyr,bind_rows) importFrom(dplyr,filter) importFrom(dplyr,full_join) +importFrom(dplyr,group_by) importFrom(dplyr,mutate) importFrom(dplyr,select) +importFrom(dplyr,summarise) importFrom(lubridate,days) importFrom(lubridate,is.timepoint) +importFrom(purrr,map_chr) importFrom(purrr,map_dbl) importFrom(stats,Gamma) importFrom(stats,as.formula) diff --git a/R/marginal_model.R b/R/marginal_model.R index 6bfe5843b..356189f70 100644 --- a/R/marginal_model.R +++ b/R/marginal_model.R @@ -123,6 +123,21 @@ epidist_formula_model.epidist_marginal_model <- function( return(formula) } +#' @method epidist_transform_data_model epidist_marginal_model +#' @family marginal_model +#' @importFrom purrr map_chr +#' @export +epidist_transform_data_model.epidist_marginal_model <- function( + data, family, formula, ...) { + required_cols <- c( + "delay_lwr", "delay_upr", "relative_obs_time", "pwindow", "swindow" + ) + trans_data <- data |> + .summarise_n_by_formula(by = required_cols, formula = formula) |> + new_epidist_marginal_model() + return(trans_data) +} + #' @method epidist_stancode epidist_marginal_model #' @importFrom brms stanvar #' @family marginal_model diff --git a/R/utils.R b/R/utils.R index 826beb3e9..51d6b7da7 100644 --- a/R/utils.R +++ b/R/utils.R @@ -204,6 +204,46 @@ return(formula) } +#' Extract distributional parameter terms from a brms formula +#' +#' This function extracts all unique terms from the right-hand side of all +#' distributional parameters in a brms formula. +#' +#' @param formula A `brms formula object +#' @return A character vector of unique terms +#' @keywords internal +.extract_dpar_terms <- function(formula) { + terms <- brms::brmsterms(formula) + # Extract all terms from the right hand side of all dpars + dpar_terms <- purrr::map(terms$dpars, \(x) all.vars(x$allvars)) + dpar_terms <- unique(unlist(dpar_terms)) + return(dpar_terms) +} + +#' Summarise data by grouping variables and count occurrences +#' +#' @param data A `data.frame` to summarise which must contain a `n` column +#' which is a count of occurrences. +#' @param by Character vector of column names to group by. +#' @param formula Optional `brms` formula object to extract additional grouping +#' terms from. +#' @return A `data.frame` summarised by the grouping variables with counts +#' @keywords internal +#' @importFrom dplyr group_by summarise across +.summarise_n_by_formula <- function(data, by = character(), formula = NULL) { + if (!is.null(formula)) { + formula_terms <- .extract_dpar_terms(formula) + by <- c(by, formula_terms) + } + # Remove duplicates + by <- unique(by) + + data |> + tibble::as_tibble() |> + summarise(n = sum(.data$n), .by = dplyr::all_of(by)) +} + + #' Rename the columns of a `data.frame` #' #' @param df ... diff --git a/man/dot-extract_dpar_terms.Rd b/man/dot-extract_dpar_terms.Rd new file mode 100644 index 000000000..3ae238bec --- /dev/null +++ b/man/dot-extract_dpar_terms.Rd @@ -0,0 +1,19 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/utils.R +\name{.extract_dpar_terms} +\alias{.extract_dpar_terms} +\title{Extract distributional parameter terms from a brms formula} +\usage{ +.extract_dpar_terms(formula) +} +\arguments{ +\item{formula}{A `brms formula object} +} +\value{ +A character vector of unique terms +} +\description{ +This function extracts all unique terms from the right-hand side of all +distributional parameters in a brms formula. +} +\keyword{internal} diff --git a/man/dot-summarise_n_by_formula.Rd b/man/dot-summarise_n_by_formula.Rd new file mode 100644 index 000000000..f68b462b3 --- /dev/null +++ b/man/dot-summarise_n_by_formula.Rd @@ -0,0 +1,24 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/utils.R +\name{.summarise_n_by_formula} +\alias{.summarise_n_by_formula} +\title{Summarise data by grouping variables and count occurrences} +\usage{ +.summarise_n_by_formula(data, by = character(), formula = NULL) +} +\arguments{ +\item{data}{A \code{data.frame} to summarise which must contain a \code{n} column +which is a count of occurrences.} + +\item{by}{Character vector of column names to group by.} + +\item{formula}{Optional \code{brms} formula object to extract additional grouping +terms from.} +} +\value{ +A \code{data.frame} summarised by the grouping variables with counts +} +\description{ +Summarise data by grouping variables and count occurrences +} +\keyword{internal} diff --git a/vignettes/ebola.Rmd b/vignettes/ebola.Rmd index bd1c199fe..549e90b0a 100644 --- a/vignettes/ebola.Rmd +++ b/vignettes/ebola.Rmd @@ -219,7 +219,7 @@ Second, because we also did not supply an observation time column (`obs_date`), To prepare the data for use with the latent individual model, we define the data as being a `epidist_latent_model` model object: ```{r} -obs_prep <- as_epidist_latent_model(linelist_data) +obs_prep <- as_epidist_marginal_model(linelist_data) head(obs_prep) ``` From 0adba0efdbc564dbb465b5c8cc01fe45ac0b3071 Mon Sep 17 00:00:00 2001 From: Sam Date: Mon, 2 Dec 2024 11:07:53 +0000 Subject: [PATCH 34/62] add a helper to find meaningful relative_obs_times --- R/marginal_model.R | 31 +++++++++++++++++-- ...st_marginal_model.epidist_linelist_data.Rd | 7 ++++- 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/R/marginal_model.R b/R/marginal_model.R index 356189f70..8811ff240 100644 --- a/R/marginal_model.R +++ b/R/marginal_model.R @@ -10,11 +10,16 @@ as_epidist_marginal_model <- function(data) { #' The marginal model method for `epidist_linelist_data` objects #' #' @param data An `epidist_linelist_data` object +#' @param obs_time_threshold Ratio used to determine threshold for setting +#' relative observation times to Inf. Observation times greater than +#' `obs_time_threshold` times the maximum delay will be set to Inf to improve +#' model efficiency. Default is 2. #' @method as_epidist_marginal_model epidist_linelist_data #' @family marginal_model #' @autoglobal #' @export -as_epidist_marginal_model.epidist_linelist_data <- function(data) { +as_epidist_marginal_model.epidist_linelist_data <- function( + data, obs_time_threshold = 2) { assert_epidist(data) data <- data |> @@ -22,11 +27,32 @@ as_epidist_marginal_model.epidist_linelist_data <- function(data) { pwindow = .data$ptime_upr - .data$ptime_lwr, swindow = .data$stime_upr - .data$stime_lwr, relative_obs_time = .data$obs_time - .data$ptime_lwr, + orig_relative_obs_time = .data$obs_time - .data$ptime_lwr, delay_lwr = .data$stime_lwr - .data$ptime_lwr, delay_upr = .data$stime_upr - .data$ptime_lwr, n = 1 ) + # Calculate maximum delay + max_delay <- max(data$delay_upr, na.rm = TRUE) + threshold <- max_delay * obs_time_threshold + + # Count observations beyond threshold + n_beyond <- sum(data$relative_obs_time > threshold, na.rm = TRUE) + + if (n_beyond > 0) { + cli::cli_inform(c( + "!" = paste0( + "Setting {n_beyond} observation time{?s} beyond ", + "{threshold} (={obs_time_threshold}x max delay) to Inf. ", + "This improves model efficiency by reducing unique observation times ", + "while maintaining model accuracy as these times should have ", + "negligible impact." + ) + )) + data$relative_obs_time[data$relative_obs_time > threshold] <- Inf + } + data <- new_epidist_marginal_model(data) assert_epidist(data) return(data) @@ -50,13 +76,14 @@ assert_epidist.epidist_marginal_model <- function(data, ...) { assert_data_frame(data) assert_names(names(data), must.include = c( "pwindow", "swindow", "delay_lwr", "delay_upr", "n", - "relative_obs_time" + "relative_obs_time", "orig_relative_obs_time" )) assert_numeric(data$pwindow, lower = 0) assert_numeric(data$swindow, lower = 0) assert_integerish(data$delay_lwr) assert_integerish(data$delay_upr) assert_numeric(data$relative_obs_time) + assert_numeric(data$orig_relative_obs_time) if (!all(abs(data$delay_upr - (data$delay_lwr + data$swindow)) < 1e-10)) { cli::cli_abort( "delay_upr must equal delay_lwr + swindow" diff --git a/man/as_epidist_marginal_model.epidist_linelist_data.Rd b/man/as_epidist_marginal_model.epidist_linelist_data.Rd index a77576f6d..2754f4058 100644 --- a/man/as_epidist_marginal_model.epidist_linelist_data.Rd +++ b/man/as_epidist_marginal_model.epidist_linelist_data.Rd @@ -4,10 +4,15 @@ \alias{as_epidist_marginal_model.epidist_linelist_data} \title{The marginal model method for \code{epidist_linelist_data} objects} \usage{ -\method{as_epidist_marginal_model}{epidist_linelist_data}(data) +\method{as_epidist_marginal_model}{epidist_linelist_data}(data, obs_time_threshold = 2) } \arguments{ \item{data}{An \code{epidist_linelist_data} object} + +\item{obs_time_threshold}{Ratio used to determine threshold for setting +relative observation times to Inf. Observation times greater than +\code{obs_time_threshold} times the maximum delay will be set to Inf to improve +model efficiency. Default is 2.} } \description{ The marginal model method for \code{epidist_linelist_data} objects From 7c6f8f81abca37ce2186316a9c85b85cf9cddb3e Mon Sep 17 00:00:00 2001 From: Sam Date: Mon, 2 Dec 2024 11:54:25 +0000 Subject: [PATCH 35/62] get the full ebola vignette working with new variable requirements --- R/marginal_model.R | 26 +++++++++++++++++++++-- vignettes/ebola.Rmd | 48 +++++++++++++++++++++++++++++++------------ vignettes/epidist.Rmd | 2 +- 3 files changed, 60 insertions(+), 16 deletions(-) diff --git a/R/marginal_model.R b/R/marginal_model.R index 8811ff240..147048ec2 100644 --- a/R/marginal_model.R +++ b/R/marginal_model.R @@ -76,14 +76,13 @@ assert_epidist.epidist_marginal_model <- function(data, ...) { assert_data_frame(data) assert_names(names(data), must.include = c( "pwindow", "swindow", "delay_lwr", "delay_upr", "n", - "relative_obs_time", "orig_relative_obs_time" + "relative_obs_time" )) assert_numeric(data$pwindow, lower = 0) assert_numeric(data$swindow, lower = 0) assert_integerish(data$delay_lwr) assert_integerish(data$delay_upr) assert_numeric(data$relative_obs_time) - assert_numeric(data$orig_relative_obs_time) if (!all(abs(data$delay_upr - (data$delay_lwr + data$swindow)) < 1e-10)) { cli::cli_abort( "delay_upr must equal delay_lwr + swindow" @@ -159,9 +158,32 @@ epidist_transform_data_model.epidist_marginal_model <- function( required_cols <- c( "delay_lwr", "delay_upr", "relative_obs_time", "pwindow", "swindow" ) + n_rows_before <- nrow(data) + trans_data <- data |> .summarise_n_by_formula(by = required_cols, formula = formula) |> new_epidist_marginal_model() + n_rows_after <- nrow(trans_data) + if (n_rows_before > n_rows_after) { + cli::cli_inform("Data summarised by unique combinations of:") + + if (length(all.vars(formula[[3]])) > 0) { + cli::cli_inform( + paste0("* Formula terms: {.code {all.vars(formula[[3]])}}") + ) + } + + cli::cli_inform(paste0( + "* Delay windows: delay bounds, observation time, ", + "and primary censoring window" + )) + + cli::cli_inform(paste0( + "i" = "Reduced from {n_rows_before} to {n_rows_after} rows. ", # nolint + "This should improve model efficiency with no loss of information." + )) + } + return(trans_data) } diff --git a/vignettes/ebola.Rmd b/vignettes/ebola.Rmd index 549e90b0a..5ac15f6bc 100644 --- a/vignettes/ebola.Rmd +++ b/vignettes/ebola.Rmd @@ -241,7 +241,7 @@ fit <- epidist( algorithm = "sampling", chains = 2, cores = 2, - refresh = as.integer(interactive()), + refresh = ifelse(interactive(), 250, 0), seed = 1, backend = "cmdstanr" ) @@ -267,7 +267,7 @@ fit_sex <- epidist( algorithm = "sampling", chains = 2, cores = 2, - refresh = as.integer(interactive()), + refresh = ifelse(interactive(), 250, 0), seed = 1, backend = "cmdstanr" ) @@ -298,7 +298,7 @@ fit_sex_district <- epidist( algorithm = "sampling", chains = 2, cores = 2, - refresh = as.integer(interactive()), + refresh = ifelse(interactive(), 250, 0), seed = 1, backend = "cmdstanr" ) @@ -321,9 +321,15 @@ In Figure \@ref(fig:epred) we show the posterior expectation of the delay distri Figure \@ref(fig:epred)B illustrates the higher mean of men as compared with women. ```{r} +# add dummmy variables +add_marginal_dummy_vars <- function(data) { + data |> + mutate(relative_obs_time = NA, pwindow = NA, delay_upr = NA, swindow = NA) +} + epred_draws <- obs_prep |> data_grid(NA) |> - mutate(relative_obs_time = NA, pwindow = NA, swindow = NA) |> + add_marginal_dummy_vars() |> add_epred_draws(fit, dpar = TRUE) epred_base_figure <- epred_draws |> @@ -334,7 +340,7 @@ epred_base_figure <- epred_draws |> epred_draws_sex <- obs_prep |> data_grid(sex) |> - mutate(relative_obs_time = NA, pwindow = NA, swindow = NA) |> + add_marginal_dummy_vars() |> add_epred_draws(fit_sex, dpar = TRUE) epred_sex_figure <- epred_draws_sex |> @@ -345,7 +351,7 @@ epred_sex_figure <- epred_draws_sex |> epred_draws_sex_district <- obs_prep |> data_grid(sex, district) |> - mutate(relative_obs_time = NA, pwindow = NA, swindow = NA) |> + add_marginal_dummy_vars() |> add_epred_draws(fit_sex_district, dpar = TRUE) epred_sex_district_figure <- epred_draws_sex_district |> @@ -376,7 +382,7 @@ For example, for the `mu` parameter in the sex-district stratified model (Figure linpred_draws_sex_district <- obs_prep |> as.data.frame() |> data_grid(sex, district) |> - mutate(relative_obs_time = NA, pwindow = NA, swindow = NA) |> + add_marginal_dummy_vars() |> add_linpred_draws(fit_sex_district, dpar = TRUE) ``` @@ -404,9 +410,17 @@ To do this, we set each of `pwindow` and `swindow` to 1 for daily censoring, and Figure \@ref(fig:pmf) shows the result, where the few delays greater than 30 are omitted from the figure. ```{r} +add_marginal_pmf_vars <- function(data) { + data |> + mutate( + relative_obs_time = 1000, pwindow = 1, swindow = 1, + delay_upr = .data$delay_lwr + .data$swindow + ) +} + draws_pmf <- obs_prep |> as.data.frame() |> - mutate(relative_obs_time = 1000, pwindow = 1, swindow = 1) |> + add_marginal_pmf_vars() |> add_predicted_draws(fit, ndraws = 1000) pmf_base_figure <- ggplot(draws_pmf, aes(x = .prediction)) + @@ -418,7 +432,7 @@ pmf_base_figure <- ggplot(draws_pmf, aes(x = .prediction)) + draws_sex_pmf <- obs_prep |> as.data.frame() |> data_grid(sex) |> - mutate(relative_obs_time = 1000, pwindow = 1, swindow = 1) |> + add_marginal_pmf_vars() |> add_predicted_draws(fit_sex, ndraws = 1000) pmf_sex_figure <- draws_sex_pmf |> @@ -432,7 +446,7 @@ pmf_sex_figure <- draws_sex_pmf |> draws_sex_district_pmf <- obs_prep |> as.data.frame() |> data_grid(sex, district) |> - mutate(relative_obs_time = 1000, pwindow = 1, swindow = 1) |> + add_marginal_pmf_vars() |> add_predicted_draws(fit_sex_district, ndraws = 1000) pmf_sex_district_figure <- draws_sex_district_pmf |> @@ -468,9 +482,17 @@ The posterior predictive distribution under no truncation and no censoring. That is to produce continuous delay times (Figure \@ref(fig:pdf)): ```{r} +add_marginal_pdf_vars <- function(data) { + data |> + mutate( + relative_obs_time = 1000, pwindow = 0, swindow = 0, + delay_upr = .data$delay_lwr + .data$swindow + ) +} + draws_pdf <- obs_prep |> as.data.frame() |> - mutate(relative_obs_time = 1000, pwindow = 0, swindow = 0) |> + add_marginal_pdf_vars() |> add_predicted_draws(fit, ndraws = 1000) pdf_base_figure <- ggplot(draws_pdf, aes(x = .prediction)) + @@ -482,7 +504,7 @@ pdf_base_figure <- ggplot(draws_pdf, aes(x = .prediction)) + draws_sex_pdf <- obs_prep |> as.data.frame() |> data_grid(sex) |> - mutate(relative_obs_time = 1000, pwindow = 0, swindow = 0) |> + add_marginal_pdf_vars() |> add_predicted_draws(fit_sex, ndraws = 1000) pdf_sex_figure <- draws_sex_pdf |> @@ -496,7 +518,7 @@ pdf_sex_figure <- draws_sex_pdf |> draws_sex_district_pdf <- obs_prep |> as.data.frame() |> data_grid(sex, district) |> - mutate(relative_obs_time = 1000, pwindow = 0, swindow = 0) |> + add_marginal_pdf_vars() |> add_predicted_draws(fit_sex_district, ndraws = 1000) pdf_sex_district_figure <- draws_sex_district_pdf |> diff --git a/vignettes/epidist.Rmd b/vignettes/epidist.Rmd index 4dee7b343..4198b4618 100644 --- a/vignettes/epidist.Rmd +++ b/vignettes/epidist.Rmd @@ -285,7 +285,7 @@ In particular, we use the the No-U-Turn Sampler (NUTS) Markov chain Monte Carlo ```{r} fit <- epidist( - data = data, chains = 2, cores = 2, refresh = as.integer(interactive()) + data = data, chains = 2, cores = 2, refresh = ifelse(interactive(), 250, 0) ) ``` From 85c690cd25633cf6dd3e8d60daa2d879c9bf8ed5 Mon Sep 17 00:00:00 2001 From: Sam Date: Mon, 2 Dec 2024 12:00:35 +0000 Subject: [PATCH 36/62] improve return messages --- R/marginal_model.R | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/R/marginal_model.R b/R/marginal_model.R index 147048ec2..45990f9b3 100644 --- a/R/marginal_model.R +++ b/R/marginal_model.R @@ -169,12 +169,12 @@ epidist_transform_data_model.epidist_marginal_model <- function( if (length(all.vars(formula[[3]])) > 0) { cli::cli_inform( - paste0("* Formula terms: {.code {all.vars(formula[[3]])}}") + paste0("* Formula variables: {.code {all.vars(formula[[3]])}}") ) } cli::cli_inform(paste0( - "* Delay windows: delay bounds, observation time, ", + "* Model variables: delay bounds, observation time, ", "and primary censoring window" )) From 099278b9c1197141508031480e639a5b38a6c27b Mon Sep 17 00:00:00 2001 From: Sam Date: Mon, 2 Dec 2024 12:04:53 +0000 Subject: [PATCH 37/62] update approx vignette --- vignettes/approx-inference.Rmd | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vignettes/approx-inference.Rmd b/vignettes/approx-inference.Rmd index 819992611..bb91d8eee 100644 --- a/vignettes/approx-inference.Rmd +++ b/vignettes/approx-inference.Rmd @@ -146,7 +146,7 @@ linelist_data <- as_epidist_linelist_data( obs_time = obs_cens_trunc_samp$obs_time ) -data <- as_epidist_latent_model(linelist_data) +data <- as_epidist_marginal_model(linelist_data) t <- proc.time() fit_hmc <- epidist(data = data, algorithm = "sampling", backend = "cmdstanr") From 4d98a7b489dc53e28606a4d11c07f42def7a57fd Mon Sep 17 00:00:00 2001 From: Sam Date: Mon, 2 Dec 2024 14:32:44 +0000 Subject: [PATCH 38/62] get ebole vignette passing by checking pp and related inputs --- R/marginal_model.R | 2 +- inst/stan/marginal_model/functions.stan | 9 ++++----- vignettes/ebola.Rmd | 14 +++----------- 3 files changed, 8 insertions(+), 17 deletions(-) diff --git a/R/marginal_model.R b/R/marginal_model.R index 45990f9b3..a75278be8 100644 --- a/R/marginal_model.R +++ b/R/marginal_model.R @@ -144,7 +144,7 @@ epidist_formula_model.epidist_marginal_model <- function( # data is only used to dispatch on formula <- stats::update( formula, delay_lwr | weights(n) + - vreal(delay_upr, relative_obs_time, pwindow, swindow) ~ . + vreal(relative_obs_time, pwindow, swindow, delay_upr) ~ . ) return(formula) } diff --git a/inst/stan/marginal_model/functions.stan b/inst/stan/marginal_model/functions.stan index 1985413c5..3c1f05cf1 100644 --- a/inst/stan/marginal_model/functions.stan +++ b/inst/stan/marginal_model/functions.stan @@ -11,18 +11,17 @@ * * @param y Real value of observed delay * @param dpars_A Distribution parameters (replaced via regex) - * @param y_upper Upper bound of delay interval * @param relative_obs_t Observation time relative to primary window start * @param pwindow_width Primary window width (actual time scale) * @param swindow_width Secondary window width (actual time scale) + * @param y_upper Upper bound of delay interval * @param primary_params Array of parameters for primary distribution * * @return Log probability mass with censoring adjustment for marginal model */ - real marginal_family_lpmf(data int y, dpars_A, data real y_upper, - data real relative_obs_t, data real pwindow_width, - data real swindow_width, - array[] real primary_params) { + real marginal_family_lpmf(data int y, dpars_A, data real relative_obs_t, + data real pwindow_width, data real swindow_width, + data real y_upper, array[] real primary_params) { return primarycensored_lpmf( y | dist_id, {dpars_B}, pwindow_width, y_upper, relative_obs_t, diff --git a/vignettes/ebola.Rmd b/vignettes/ebola.Rmd index 5ac15f6bc..c0cded2f6 100644 --- a/vignettes/ebola.Rmd +++ b/vignettes/ebola.Rmd @@ -406,20 +406,18 @@ In this section, we demonstrate how to produce either a discrete probability mas ### Discrete probability mass function To generate a discrete probability mass function (PMF) we predict the delay distribution that would be observed with daily censoring and no right truncation. -To do this, we set each of `pwindow` and `swindow` to 1 for daily censoring, and `relative_obs_time` to 1000 for no censoring. +To do this, we set each of `pwindow` and `swindow` to 1 for daily censoring, and `relative_obs_time` to `Inf` for no censoring. Figure \@ref(fig:pmf) shows the result, where the few delays greater than 30 are omitted from the figure. ```{r} add_marginal_pmf_vars <- function(data) { data |> mutate( - relative_obs_time = 1000, pwindow = 1, swindow = 1, - delay_upr = .data$delay_lwr + .data$swindow + relative_obs_time = Inf, pwindow = 1, swindow = 1, delay_upr = NA ) } draws_pmf <- obs_prep |> - as.data.frame() |> add_marginal_pmf_vars() |> add_predicted_draws(fit, ndraws = 1000) @@ -430,7 +428,6 @@ pmf_base_figure <- ggplot(draws_pmf, aes(x = .prediction)) + theme_minimal() draws_sex_pmf <- obs_prep |> - as.data.frame() |> data_grid(sex) |> add_marginal_pmf_vars() |> add_predicted_draws(fit_sex, ndraws = 1000) @@ -444,7 +441,6 @@ pmf_sex_figure <- draws_sex_pmf |> theme_minimal() draws_sex_district_pmf <- obs_prep |> - as.data.frame() |> data_grid(sex, district) |> add_marginal_pmf_vars() |> add_predicted_draws(fit_sex_district, ndraws = 1000) @@ -485,13 +481,11 @@ That is to produce continuous delay times (Figure \@ref(fig:pdf)): add_marginal_pdf_vars <- function(data) { data |> mutate( - relative_obs_time = 1000, pwindow = 0, swindow = 0, - delay_upr = .data$delay_lwr + .data$swindow + relative_obs_time = Inf, pwindow = 0, swindow = 0, delay_upr = NA ) } draws_pdf <- obs_prep |> - as.data.frame() |> add_marginal_pdf_vars() |> add_predicted_draws(fit, ndraws = 1000) @@ -502,7 +496,6 @@ pdf_base_figure <- ggplot(draws_pdf, aes(x = .prediction)) + theme_minimal() draws_sex_pdf <- obs_prep |> - as.data.frame() |> data_grid(sex) |> add_marginal_pdf_vars() |> add_predicted_draws(fit_sex, ndraws = 1000) @@ -516,7 +509,6 @@ pdf_sex_figure <- draws_sex_pdf |> theme_minimal() draws_sex_district_pdf <- obs_prep |> - as.data.frame() |> data_grid(sex, district) |> add_marginal_pdf_vars() |> add_predicted_draws(fit_sex_district, ndraws = 1000) From 7c972261fa458a5da38b6cd8ab9b33a572211858 Mon Sep 17 00:00:00 2001 From: Sam Date: Mon, 2 Dec 2024 15:22:16 +0000 Subject: [PATCH 39/62] add marginal model integration tests --- R/marginal_model.R | 5 +- tests/testthat/setup.R | 22 +++++++- tests/testthat/test-int-marginal_model.R | 72 ++++++++++++++++++++---- tests/testthat/test-utils.R | 55 ++++++++++++++++++ 4 files changed, 139 insertions(+), 15 deletions(-) diff --git a/R/marginal_model.R b/R/marginal_model.R index a75278be8..7723140b6 100644 --- a/R/marginal_model.R +++ b/R/marginal_model.R @@ -167,9 +167,10 @@ epidist_transform_data_model.epidist_marginal_model <- function( if (n_rows_before > n_rows_after) { cli::cli_inform("Data summarised by unique combinations of:") - if (length(all.vars(formula[[3]])) > 0) { + formula_vars <- setdiff(names(trans_data), c(required_cols, "n")) + if (length(formula_vars) > 0) { cli::cli_inform( - paste0("* Formula variables: {.code {all.vars(formula[[3]])}}") + paste0("* Formula variables: {.code {formula_vars}}") ) } diff --git a/tests/testthat/setup.R b/tests/testthat/setup.R index 03271e4cf..f9d76a014 100644 --- a/tests/testthat/setup.R +++ b/tests/testthat/setup.R @@ -113,13 +113,16 @@ sim_obs_sex <- as_epidist_linelist_data( sim_obs_sex$obs_time, sex = sim_obs_sex$sex ) - prep_obs <- as_epidist_latent_model(sim_obs) prep_naive_obs <- as_epidist_naive_model(sim_obs) prep_marginal_obs <- as_epidist_marginal_model(sim_obs) prep_obs_gamma <- as_epidist_latent_model(sim_obs_gamma) prep_obs_sex <- as_epidist_latent_model(sim_obs_sex) +prep_marginal_obs <- as_epidist_marginal_model(sim_obs) +prep_marginal_obs_gamma <- as_epidist_marginal_model(sim_obs_gamma) +prep_marginal_obs_sex <- as_epidist_marginal_model(sim_obs_sex) + if (not_on_cran()) { set.seed(1) fit <- epidist( @@ -130,6 +133,10 @@ if (not_on_cran()) { fit_rstan <- epidist( data = prep_obs, seed = 1, chains = 2, cores = 2, silent = 2, refresh = 0 ) + fit_marginal <- suppressMessages(epidist( + data = prep_marginal_obs, seed = 1, chains = 2, cores = 2, silent = 2, + refresh = 0, backend = "cmdstanr" + )) fit_gamma <- epidist( data = prep_obs_gamma, family = Gamma(link = "log"), @@ -137,10 +144,23 @@ if (not_on_cran()) { backend = "cmdstanr" ) + fit_marginal_gamma <- suppressMessages(epidist( + data = prep_marginal_obs_gamma, family = Gamma(link = "log"), + seed = 1, chains = 2, cores = 2, silent = 2, refresh = 0, + backend = "cmdstanr" + )) + fit_sex <- epidist( data = prep_obs_sex, formula = bf(mu ~ 1 + sex, sigma ~ 1 + sex), seed = 1, silent = 2, refresh = 0, cores = 2, chains = 2, backend = "cmdstanr" ) + + fit_marginal_sex <- suppressMessages(epidist( + data = prep_marginal_obs_sex, + formula = bf(mu ~ 1 + sex, sigma ~ 1 + sex), + seed = 1, silent = 2, refresh = 50, + cores = 2, chains = 2, backend = "cmdstanr" + )) } diff --git a/tests/testthat/test-int-marginal_model.R b/tests/testthat/test-int-marginal_model.R index 97099b747..d44131b90 100644 --- a/tests/testthat/test-int-marginal_model.R +++ b/tests/testthat/test-int-marginal_model.R @@ -6,10 +6,10 @@ test_that("epidist.epidist_marginal_model Stan code has no syntax errors in the default case", { # nolint: line_length_linter. skip_on_cran() - stancode <- epidist( + stancode <- suppressMessages(epidist( data = prep_marginal_obs, fn = brms::make_stancode - ) + )) mod <- cmdstanr::cmdstan_model( stan_file = cmdstanr::write_stan_file(stancode), compile = FALSE ) @@ -17,18 +17,66 @@ test_that("epidist.epidist_marginal_model Stan code has no syntax errors in the }) test_that("epidist.epidist_marginal_model fits and the MCMC converges in the default case", { # nolint: line_length_linter. + # Note: this test is stochastic. See note at the top of this script + skip_on_cran() + expect_s3_class(fit_marginal, "brmsfit") + expect_s3_class(fit_marginal, "epidist_fit") + expect_convergence(fit_marginal) +}) + +test_that("epidist.epidist_marginal_model recovers the simulation settings for the delay distribution in the default case", { # nolint: line_length_linter. # Note: this test is stochastic. See note at the top of this script skip_on_cran() set.seed(1) - fit <- epidist( - data = prep_marginal_obs, - seed = 1, - silent = 2, refresh = 0, - cores = 2, - chains = 2, - backend = "cmdstanr" + pred <- predict_delay_parameters(fit_marginal) + expect_equal(mean(pred$mu), meanlog, tolerance = 0.1) + expect_equal(mean(pred$sigma), sdlog, tolerance = 0.1) +}) + +test_that("epidist.epidist_marginal_model fits and the MCMC converges in the gamma delay case", { # nolint: line_length_linter. + # Note: this test is stochastic. See note at the top of this script + skip_on_cran() + set.seed(1) + expect_s3_class(fit_marginal_gamma, "brmsfit") + expect_s3_class(fit_marginal_gamma, "epidist_fit") + expect_convergence(fit_marginal_gamma) +}) + +test_that("epidist.epidist_marginal_model recovers the simulation settings for the delay distribution in the gamma delay case", { # nolint: line_length_linter. + # Note: this test is stochastic. See note at the top of this script + skip_on_cran() + set.seed(1) + draws_gamma <- posterior::as_draws_df(fit_marginal_gamma$fit) + draws_gamma_mu <- exp(draws_gamma$Intercept) + draws_gamma_shape <- exp(draws_gamma$Intercept_shape) + draws_gamma_mu_ecdf <- ecdf(draws_gamma_mu) + draws_gamma_shape_ecdf <- ecdf(draws_gamma_shape) + quantile_mu <- draws_gamma_mu_ecdf(mu) + quantile_shape <- draws_gamma_shape_ecdf(shape) + expect_gte(quantile_mu, 0.025) + expect_lte(quantile_mu, 0.975) + expect_gte(quantile_shape, 0.025) + expect_lte(quantile_shape, 0.975) +}) + +test_that("epidist.epidist_marginal_model fits and recovers a sex effect", { # nolint: line_length_linter. + # Note: this test is stochastic. See note at the top of this script + skip_on_cran() + set.seed(1) + expect_s3_class(fit_marginal_sex, "brmsfit") + expect_s3_class(fit_marginal_sex, "epidist_fit") + expect_convergence(fit_marginal_sex) + + draws <- posterior::as_draws_df(fit_marginal_sex$fit) + expect_equal(mean(draws$b_Intercept), meanlog_m, tolerance = 0.3) + expect_equal( + mean(draws$b_Intercept + draws$b_sex), meanlog_f, + tolerance = 0.3 + ) + expect_equal(mean(exp(draws$b_sigma_Intercept)), sdlog_m, tolerance = 0.3) + expect_equal( + mean(exp(draws$b_sigma_Intercept + draws$b_sigma_sex)), + sdlog_f, + tolerance = 0.3 ) - expect_s3_class(fit, "brmsfit") - expect_s3_class(fit, "epidist_fit") - expect_convergence(fit) }) diff --git a/tests/testthat/test-utils.R b/tests/testthat/test-utils.R index 1fb3add13..e136db754 100644 --- a/tests/testthat/test-utils.R +++ b/tests/testthat/test-utils.R @@ -98,3 +98,58 @@ test_that(".make_intercepts_explicit does not add an intercept if the distributi expect_identical(formula$pforms$mu, formula_updated$pforms$mu) expect_identical(formula$pforms$sigma, formula_updated$pforms$sigma) }) + +test_that( + ".summarise_n_by_formula correctly summarizes counts by grouping variables", + { + df <- tibble::tibble( + x = c(1, 1, 2, 2), + y = c("a", "b", "a", "b"), + n = c(2, 3, 4, 1) + ) + + # Test grouping by single variable + result <- .summarise_n_by_formula(df, by = "x") + expect_identical(result$x, c(1, 2)) + expect_identical(result$n, c(5, 5)) + + # Test grouping by multiple variable + result <- .summarise_n_by_formula(df, by = c("x", "y")) + expect_identical(result$x, c(1, 1, 2, 2)) + expect_identical(result$y, c("a", "b", "a", "b")) + expect_identical(result$n, c(2, 3, 4, 1)) + + # Test with formula + formula <- bf(mu ~ x + y) + result <- .summarise_n_by_formula(df, formula = formula) + expect_identical(result$x, c(1, 1, 2, 2)) + expect_identical(result$y, c("a", "b", "a", "b")) + expect_identical(result$n, c(2, 3, 4, 1)) + + # Test with both by and formula + formula <- bf(mu ~ y) + result <- .summarise_n_by_formula(df, by = "x", formula = formula) + expect_identical(result$x, c(1, 1, 2, 2)) + expect_identical(result$y, c("a", "b", "a", "b")) + expect_identical(result$n, c(2, 3, 4, 1)) + } +) + +test_that( + ".summarise_n_by_formula handles missing grouping variables appropriately", + { + df <- data.frame(x = 1:2, n = c(1, 2)) + expect_error( + .summarise_n_by_formula(df, by = "missing"), + "object 'missing' not found" + ) + } +) + +test_that(".summarise_n_by_formula requires n column in data", { + df <- data.frame(x = 1:2) + expect_error( + .summarise_n_by_formula(df, by = "x"), + "Column `n` not found in `.data`." + ) +}) From dd3831a678c1d17cd787fedcda1718c6c433dacb Mon Sep 17 00:00:00 2001 From: Sam Date: Mon, 2 Dec 2024 16:06:16 +0000 Subject: [PATCH 40/62] expand post processing tests --- R/marginal_model.R | 2 + tests/testthat/test-gen.R | 190 +++++++++++++++--------------- tests/testthat/test-postprocess.R | 104 +++++++++------- 3 files changed, 162 insertions(+), 134 deletions(-) diff --git a/R/marginal_model.R b/R/marginal_model.R index 7723140b6..932239ce2 100644 --- a/R/marginal_model.R +++ b/R/marginal_model.R @@ -30,6 +30,7 @@ as_epidist_marginal_model.epidist_linelist_data <- function( orig_relative_obs_time = .data$obs_time - .data$ptime_lwr, delay_lwr = .data$stime_lwr - .data$ptime_lwr, delay_upr = .data$stime_upr - .data$ptime_lwr, + .row_id = dplyr::row_number(), n = 1 ) @@ -83,6 +84,7 @@ assert_epidist.epidist_marginal_model <- function(data, ...) { assert_integerish(data$delay_lwr) assert_integerish(data$delay_upr) assert_numeric(data$relative_obs_time) + assert_integerish(data$.row_id, lower = 1) if (!all(abs(data$delay_upr - (data$delay_lwr + data$swindow)) < 1e-10)) { cli::cli_abort( "delay_upr must equal delay_lwr + swindow" diff --git a/tests/testthat/test-gen.R b/tests/testthat/test-gen.R index 37b7d7bf3..93d22a8d4 100644 --- a/tests/testthat/test-gen.R +++ b/tests/testthat/test-gen.R @@ -1,119 +1,125 @@ test_that("epidist_gen_posterior_predict returns a function that outputs positive integers with length equal to draws", { # nolint: line_length_linter. skip_on_cran() - # Test lognormal - prep <- brms::prepare_predictions(fit) - i <- 1 - predict_fn <- epidist_gen_posterior_predict(lognormal()) - pred_i <- predict_fn(i = i, prep) - expect_identical(floor(pred_i), pred_i) - expect_length(pred_i, prep$ndraws) - expect_gte(min(pred_i), 0) - # Test gamma - prep_gamma <- brms::prepare_predictions(fit_gamma) - predict_fn_gamma <- epidist_gen_posterior_predict(Gamma()) - pred_i_gamma <- predict_fn_gamma(i = i, prep_gamma) - expect_identical(floor(pred_i_gamma), pred_i_gamma) - expect_length(pred_i_gamma, prep_gamma$ndraws) - expect_gte(min(pred_i_gamma), 0) + # Helper function to test predictions + test_predictions <- function(fit, family) { + prep <- brms::prepare_predictions(fit) + i <- 1 + predict_fn <- epidist_gen_posterior_predict(family) + pred_i <- predict_fn(i = i, prep) + expect_identical(floor(pred_i), pred_i) + expect_length(pred_i, prep$ndraws) + expect_gte(min(pred_i), 0) + } + + # Test lognormal - latent and marginal + test_predictions(fit, lognormal()) + test_predictions(fit_marginal, lognormal()) + + # Test gamma - latent and marginal + test_predictions(fit_gamma, Gamma()) + test_predictions(fit_marginal_gamma, Gamma()) }) test_that("epidist_gen_posterior_predict returns a function that errors for i out of bounds", { # nolint: line_length_linter. skip_on_cran() - # Test lognormal - prep <- brms::prepare_predictions(fit) - i_out_of_bounds <- length(prep$data$Y) + 1 - predict_fn <- epidist_gen_posterior_predict(lognormal()) - expect_warning( - expect_error( - predict_fn(i = i_out_of_bounds, prep) + + # Helper function to test out of bounds errors + test_out_of_bounds <- function(fit, family) { + prep <- brms::prepare_predictions(fit) + i_out_of_bounds <- length(prep$data$Y) + 1 + predict_fn <- epidist_gen_posterior_predict(family) + expect_warning( + expect_error( + predict_fn(i = i_out_of_bounds, prep) + ) ) - ) + } - # Test gamma - prep_gamma <- brms::prepare_predictions(fit_gamma) - i_out_of_bounds_gamma <- length(prep_gamma$data$Y) + 1 - predict_fn_gamma <- epidist_gen_posterior_predict(Gamma()) - expect_warning( - expect_error(predict_fn_gamma(i = i_out_of_bounds_gamma, prep_gamma)) - ) + # Test lognormal - latent and marginal + test_out_of_bounds(fit, lognormal()) + test_out_of_bounds(fit_marginal, lognormal()) + + # Test gamma - latent and marginal + test_out_of_bounds(fit_gamma, Gamma()) + test_out_of_bounds(fit_marginal_gamma, Gamma()) }) test_that("epidist_gen_posterior_predict returns a function that can generate predictions with no censoring", { # nolint: line_length_linter. skip_on_cran() - # Test lognormal - predict_fn <- epidist_gen_posterior_predict(lognormal()) - draws <- data.frame(relative_obs_time = 1000, pwindow = 0, swindow = 0) |> - tidybayes::add_predicted_draws(fit, ndraws = 100) - expect_identical(draws$.draw, 1:100) - pred <- draws$.prediction - expect_gte(min(pred), 0) - expect_true(all(abs(pred - round(pred)) > .Machine$double.eps^0.5)) - # Test gamma - predict_fn_gamma <- epidist_gen_posterior_predict(Gamma()) - draws_gamma <- data.frame( - relative_obs_time = 1000, pwindow = 0, swindow = 0 - ) |> - tidybayes::add_predicted_draws(fit_gamma, ndraws = 100) - expect_identical(draws_gamma$.draw, 1:100) - pred_gamma <- draws_gamma$.prediction - expect_gte(min(pred_gamma), 0) - expect_true( - all(abs(pred_gamma - round(pred_gamma)) > .Machine$double.eps^0.5) - ) + # Helper function to test uncensored predictions + test_uncensored <- function(fit, family) { + predict_fn <- epidist_gen_posterior_predict(family) + draws <- data.frame( + relative_obs_time = Inf, pwindow = 0, swindow = 0, delay_upr = NA + ) |> + tidybayes::add_predicted_draws(fit, ndraws = 100) + expect_identical(draws$.draw, 1:100) + pred <- draws$.prediction + expect_gte(min(pred), 0) + expect_true(all(abs(pred - round(pred)) > .Machine$double.eps^0.5)) + } + + # Test lognormal - latent and marginal + test_uncensored(fit, lognormal()) + test_uncensored(fit_marginal, lognormal()) + + # Test gamma - latent and marginal + test_uncensored(fit_gamma, Gamma()) + test_uncensored(fit_marginal_gamma, Gamma()) }) test_that("epidist_gen_posterior_predict returns a function that predicts delays in the 95% credible interval", { # nolint: line_length_linter. skip_on_cran() - # Test lognormal - prep <- brms::prepare_predictions(fit) - prep$ndraws <- 1000 # Down from the 4000 for time saving - predict_fn <- epidist_gen_posterior_predict(lognormal()) - q <- purrr::map_vec(seq_along(prep$data$Y), function(i) { - y <- predict_fn(i, prep) - ecdf <- ecdf(y) - q <- ecdf(prep$data$Y[i]) - return(q) - }) - expect_lt(quantile(q, 0.1), 0.3) - expect_gt(quantile(q, 0.9), 0.7) - expect_lt(min(q), 0.1) - expect_gt(max(q), 0.9) - expect_lt(mean(q), 0.65) - expect_gt(mean(q), 0.35) - # Test gamma - prep_gamma <- brms::prepare_predictions(fit_gamma) - prep_gamma$ndraws <- 1000 - predict_fn_gamma <- epidist_gen_posterior_predict(Gamma()) - q_gamma <- purrr::map_vec(seq_along(prep_gamma$data$Y), function(i) { - y <- predict_fn_gamma(i, prep_gamma) - ecdf <- ecdf(y) - q <- ecdf(prep_gamma$data$Y[i]) - return(q) - }) - expect_lt(quantile(q_gamma, 0.1), 0.3) - expect_gt(quantile(q_gamma, 0.9), 0.7) - expect_lt(min(q_gamma), 0.1) - expect_gt(max(q_gamma), 0.9) - expect_lt(mean(q_gamma), 0.65) - expect_gt(mean(q_gamma), 0.35) + # Helper function to test credible intervals + test_credible_intervals <- function(fit, family) { + prep <- brms::prepare_predictions(fit) + prep$ndraws <- 1000 # Down from the 4000 for time saving + predict_fn <- epidist_gen_posterior_predict(family) + q <- purrr::map_vec(seq_along(prep$data$Y), function(i) { + y <- predict_fn(i, prep) + ecdf <- ecdf(y) + q <- ecdf(prep$data$Y[i]) + return(q) + }) + expect_lt(quantile(q, 0.1), 0.3) + expect_gt(quantile(q, 0.9), 0.7) + expect_lt(min(q), 0.1) + expect_gt(max(q), 0.9) + expect_lt(mean(q), 0.65) + expect_gt(mean(q), 0.35) + } + + # Test lognormal - latent and marginal + test_credible_intervals(fit, lognormal()) + test_credible_intervals(fit_marginal, lognormal()) + + # Test gamma - latent and marginal + test_credible_intervals(fit_gamma, Gamma()) + test_credible_intervals(fit_marginal_gamma, Gamma()) }) test_that("epidist_gen_posterior_epred returns a function that creates arrays with correct dimensions", { # nolint: line_length_linter. skip_on_cran() - # Test lognormal - epred <- prep_obs |> - tidybayes::add_epred_draws(fit) - expect_equal(mean(epred$.epred), 5.97, tolerance = 0.1) - expect_gte(min(epred$.epred), 0) - # Test gamma - epred_gamma <- prep_obs |> - tidybayes::add_epred_draws(fit_gamma) - expect_equal(mean(epred_gamma$.epred), 6.56, tolerance = 0.1) - expect_gte(min(epred_gamma$.epred), 0) + # Helper function to test epred + test_epred <- function(fit, expected_mean) { + epred <- prep_obs |> + mutate(delay_upr = NA) |> + tidybayes::add_epred_draws(fit) + expect_equal(mean(epred$.epred), expected_mean, tolerance = 0.1) + expect_gte(min(epred$.epred), 0) + } + + # Test lognormal - latent and marginal + test_epred(fit, 5.97) + test_epred(fit_marginal, 5.97) + + # Test gamma - latent and marginal + test_epred(fit_gamma, 6.56) + test_epred(fit_marginal_gamma, 6.56) }) test_that("epidist_gen_log_lik returns a function that produces valid log likelihoods", { # nolint: line_length_linter. diff --git a/tests/testthat/test-postprocess.R b/tests/testthat/test-postprocess.R index 798d591a2..5d898ad9b 100644 --- a/tests/testthat/test-postprocess.R +++ b/tests/testthat/test-postprocess.R @@ -1,52 +1,72 @@ -test_that("predict_delay_parameters works with NULL newdata and the latent lognormal model", { # nolint: line_length_linter. - skip_on_cran() - set.seed(1) - pred <- predict_delay_parameters(fit) - expect_s3_class(pred, "lognormal_samples") - expect_s3_class(pred, "data.frame") - expect_named(pred, c("draw", "index", "mu", "sigma", "mean", "sd")) - expect_true(all(pred$mean > 0)) - expect_true(all(pred$sd > 0)) - expect_length(unique(pred$index), nrow(prep_obs)) - expect_length(unique(pred$draw), summary(fit)$total_ndraws) -}) +test_that( + "predict_delay_parameters works with NULL newdata and the latent and marginal lognormal model", # nolint: line_length_linter. + { + skip_on_cran() + + # Helper function to test predictions + test_predictions <- function(fit, expected_rows = nrow(prep_obs)) { + set.seed(1) + pred <- predict_delay_parameters(fit) + expect_s3_class(pred, "lognormal_samples") + expect_s3_class(pred, "data.frame") + expect_named(pred, c("draw", "index", "mu", "sigma", "mean", "sd")) + expect_true(all(pred$mean > 0)) + expect_true(all(pred$sd > 0)) + expect_length(unique(pred$index), expected_rows) + expect_length(unique(pred$draw), summary(fit)$total_ndraws) + } + + # Test latent and marginal models + test_predictions(fit) + test_predictions(fit_marginal, expected_rows = 144) + } +) test_that("predict_delay_parameters accepts newdata arguments and prediction by sex recovers underlying parameters", { # nolint: line_length_linter. skip_on_cran() - set.seed(1) - pred_sex <- predict_delay_parameters(fit_sex, prep_obs_sex) - expect_s3_class(pred_sex, "lognormal_samples") - expect_s3_class(pred_sex, "data.frame") - expect_named(pred_sex, c("draw", "index", "mu", "sigma", "mean", "sd")) - expect_true(all(pred_sex$mean > 0)) - expect_true(all(pred_sex$sd > 0)) - expect_length(unique(pred_sex$index), nrow(prep_obs_sex)) - expect_length(unique(pred_sex$draw), summary(fit_sex)$total_ndraws) - pred_sex_summary <- pred_sex |> - dplyr::left_join( - dplyr::select(prep_obs_sex, index = .row_id, sex), - by = "index" - ) |> - dplyr::group_by(sex) |> - dplyr::summarise( - mu = mean(mu), - sigma = mean(sigma) + # Helper function to test sex predictions + test_sex_predictions <- function(fit, prep = prep_obs_sex) { + set.seed(1) + + pred_sex <- predict_delay_parameters(fit, prep) + expect_s3_class(pred_sex, "lognormal_samples") + expect_s3_class(pred_sex, "data.frame") + expect_named(pred_sex, c("draw", "index", "mu", "sigma", "mean", "sd")) + expect_true(all(pred_sex$mean > 0)) + expect_true(all(pred_sex$sd > 0)) + expect_length(unique(pred_sex$index), nrow(prep)) + expect_length(unique(pred_sex$draw), summary(fit)$total_ndraws) + + pred_sex_summary <- pred_sex |> + dplyr::left_join( + dplyr::select(prep, index = .row_id, sex), + by = "index" + ) |> + dplyr::group_by(sex) |> + dplyr::summarise( + mu = mean(mu), + sigma = mean(sigma) + ) + + # Correct predictions of M + expect_equal( + as.numeric(pred_sex_summary[1, c("mu", "sigma")]), + c(meanlog_m, sdlog_m), + tolerance = 0.1 ) - # Correct predictions of M - expect_equal( - as.numeric(pred_sex_summary[1, c("mu", "sigma")]), - c(meanlog_m, sdlog_m), - tolerance = 0.1 - ) + # Correction predictions of F + expect_equal( + as.numeric(pred_sex_summary[2, c("mu", "sigma")]), + c(meanlog_f, sdlog_f), + tolerance = 0.1 + ) + } - # Correction predictions of F - expect_equal( - as.numeric(pred_sex_summary[2, c("mu", "sigma")]), - c(meanlog_f, sdlog_f), - tolerance = 0.1 - ) + # Test latent and marginal models + test_sex_predictions(fit_sex) + test_sex_predictions(fit_marginal_sex, prep_marginal_obs_sex) }) test_that("add_mean_sd.lognormal_samples works with simulated lognormal distribution parameter data", { # nolint: line_length_linter. From 817738f9fd6687f9d824f48ced3a9cf8d947b82e Mon Sep 17 00:00:00 2001 From: Sam Date: Mon, 2 Dec 2024 16:42:15 +0000 Subject: [PATCH 41/62] add marginal model --- R/marginal_model.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/marginal_model.R b/R/marginal_model.R index 932239ce2..11705379d 100644 --- a/R/marginal_model.R +++ b/R/marginal_model.R @@ -77,7 +77,7 @@ assert_epidist.epidist_marginal_model <- function(data, ...) { assert_data_frame(data) assert_names(names(data), must.include = c( "pwindow", "swindow", "delay_lwr", "delay_upr", "n", - "relative_obs_time" + ".row_id", "relative_obs_time" )) assert_numeric(data$pwindow, lower = 0) assert_numeric(data$swindow, lower = 0) From 4024abae4066478de9ed9194827c3b54d11666f7 Mon Sep 17 00:00:00 2001 From: Sam Date: Mon, 2 Dec 2024 16:47:40 +0000 Subject: [PATCH 42/62] use the right transform data keyword --- R/transform_data.R | 6 +++--- man/epidist_transform_data.Rd | 4 ++-- man/epidist_transform_data_model.Rd | 4 ++-- man/epidist_transform_data_model.default.Rd | 4 ++-- 4 files changed, 9 insertions(+), 9 deletions(-) diff --git a/R/transform_data.R b/R/transform_data.R index a1acf50bc..4cbc205f1 100644 --- a/R/transform_data.R +++ b/R/transform_data.R @@ -9,7 +9,7 @@ #' @param family A description of the response distribution and link function to #' be used in the model created using [epidist_family()]. #' @param formula A formula object created using [epidist_formula()]. -#' @family transform +#' @family transform_data #' @export epidist_transform_data <- function(data, family, formula, ...) { assert_epidist(data) @@ -21,7 +21,7 @@ epidist_transform_data <- function(data, family, formula, ...) { #' #' @inheritParams epidist_transform_data #' @rdname epidist_transform_data_model -#' @family transform +#' @family transform_data #' @export epidist_transform_data_model <- function(data, family, formula, ...) { UseMethod("epidist_transform_data_model") @@ -30,7 +30,7 @@ epidist_transform_data_model <- function(data, family, formula, ...) { #' Default method for transforming data for a model #' #' @inheritParams epidist_transform_data_model -#' @family transform +#' @family transform_data #' @export epidist_transform_data_model.default <- function(data, family, formula, ...) { return(data) diff --git a/man/epidist_transform_data.Rd b/man/epidist_transform_data.Rd index 54b9a3a6a..00c80c823 100644 --- a/man/epidist_transform_data.Rd +++ b/man/epidist_transform_data.Rd @@ -23,8 +23,8 @@ export it nonetheless to be transparent about what happens inside of a call to \code{\link[=epidist]{epidist()}}. } \seealso{ -Other transform: +Other transform_data: \code{\link{epidist_transform_data_model}()}, \code{\link{epidist_transform_data_model.default}()} } -\concept{transform} +\concept{transform_data} diff --git a/man/epidist_transform_data_model.Rd b/man/epidist_transform_data_model.Rd index e7b0e8302..f8daf119e 100644 --- a/man/epidist_transform_data_model.Rd +++ b/man/epidist_transform_data_model.Rd @@ -20,8 +20,8 @@ be used in the model created using \code{\link[=epidist_family]{epidist_family() The model-specific parts of an \code{epidist_transform_data()} call } \seealso{ -Other transform: +Other transform_data: \code{\link{epidist_transform_data}()}, \code{\link{epidist_transform_data_model.default}()} } -\concept{transform} +\concept{transform_data} diff --git a/man/epidist_transform_data_model.default.Rd b/man/epidist_transform_data_model.default.Rd index 06b95aa8e..ca56c63a1 100644 --- a/man/epidist_transform_data_model.default.Rd +++ b/man/epidist_transform_data_model.default.Rd @@ -20,8 +20,8 @@ be used in the model created using \code{\link[=epidist_family]{epidist_family() Default method for transforming data for a model } \seealso{ -Other transform: +Other transform_data: \code{\link{epidist_transform_data}()}, \code{\link{epidist_transform_data_model}()} } -\concept{transform} +\concept{transform_data} From ef042d3ae81942c3ad5f927dbace2ad317b16235 Mon Sep 17 00:00:00 2001 From: Sam Date: Mon, 2 Dec 2024 16:51:11 +0000 Subject: [PATCH 43/62] add ... pass through to make constructors correct --- R/latent_model.R | 3 ++- R/marginal_model.R | 3 ++- man/as_epidist_marginal_model.Rd | 4 +++- man/new_epidist_latent_model.Rd | 4 +++- 4 files changed, 10 insertions(+), 4 deletions(-) diff --git a/R/latent_model.R b/R/latent_model.R index 3d5c72dbb..119e2ae67 100644 --- a/R/latent_model.R +++ b/R/latent_model.R @@ -38,10 +38,11 @@ as_epidist_latent_model.epidist_linelist_data <- function(data) { #' Class constructor for `epidist_latent_model` objects #' #' @param data An object to be set with the class `epidist_latent_model` +#' @param ... Additional arguments passed to methods. #' @returns An object of class `epidist_latent_model` #' @family latent_model #' @export -new_epidist_latent_model <- function(data) { +new_epidist_latent_model <- function(data, ...) { class(data) <- c("epidist_latent_model", class(data)) return(data) } diff --git a/R/marginal_model.R b/R/marginal_model.R index 11705379d..2ae76f4a8 100644 --- a/R/marginal_model.R +++ b/R/marginal_model.R @@ -1,9 +1,10 @@ #' Prepare marginal model to pass through to `brms` #' #' @param data A `data.frame` containing line list data +#' @param ... Additional arguments passed to methods. #' @family marginal_model #' @export -as_epidist_marginal_model <- function(data) { +as_epidist_marginal_model <- function(data, ...) { UseMethod("as_epidist_marginal_model") } diff --git a/man/as_epidist_marginal_model.Rd b/man/as_epidist_marginal_model.Rd index b0b846022..0b8432f1b 100644 --- a/man/as_epidist_marginal_model.Rd +++ b/man/as_epidist_marginal_model.Rd @@ -4,10 +4,12 @@ \alias{as_epidist_marginal_model} \title{Prepare marginal model to pass through to \code{brms}} \usage{ -as_epidist_marginal_model(data) +as_epidist_marginal_model(data, ...) } \arguments{ \item{data}{A \code{data.frame} containing line list data} + +\item{...}{Additional arguments passed to methods.} } \description{ Prepare marginal model to pass through to \code{brms} diff --git a/man/new_epidist_latent_model.Rd b/man/new_epidist_latent_model.Rd index 8658161c4..0816107d5 100644 --- a/man/new_epidist_latent_model.Rd +++ b/man/new_epidist_latent_model.Rd @@ -4,10 +4,12 @@ \alias{new_epidist_latent_model} \title{Class constructor for \code{epidist_latent_model} objects} \usage{ -new_epidist_latent_model(data) +new_epidist_latent_model(data, ...) } \arguments{ \item{data}{An object to be set with the class \code{epidist_latent_model}} + +\item{...}{Additional arguments passed to methods.} } \value{ An object of class \code{epidist_latent_model} From c2d46ee74ef440a0d85f4e63bd0e60e077618253 Mon Sep 17 00:00:00 2001 From: Sam Date: Mon, 2 Dec 2024 16:59:32 +0000 Subject: [PATCH 44/62] fix .summarise_n_by_formula test so error message is as expected --- tests/testthat/test-utils.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/testthat/test-utils.R b/tests/testthat/test-utils.R index e136db754..f31b452fa 100644 --- a/tests/testthat/test-utils.R +++ b/tests/testthat/test-utils.R @@ -141,7 +141,7 @@ test_that( df <- data.frame(x = 1:2, n = c(1, 2)) expect_error( .summarise_n_by_formula(df, by = "missing"), - "object 'missing' not found" + "Can't subset elements that don't exist" ) } ) From 6464e95a35dad2f4ba843d08040fff9445ee3f55 Mon Sep 17 00:00:00 2001 From: Sam Date: Mon, 2 Dec 2024 17:10:33 +0000 Subject: [PATCH 45/62] drop not required .row_id --- R/marginal_model.R | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/R/marginal_model.R b/R/marginal_model.R index 2ae76f4a8..dd4c136d3 100644 --- a/R/marginal_model.R +++ b/R/marginal_model.R @@ -31,7 +31,6 @@ as_epidist_marginal_model.epidist_linelist_data <- function( orig_relative_obs_time = .data$obs_time - .data$ptime_lwr, delay_lwr = .data$stime_lwr - .data$ptime_lwr, delay_upr = .data$stime_upr - .data$ptime_lwr, - .row_id = dplyr::row_number(), n = 1 ) @@ -78,14 +77,13 @@ assert_epidist.epidist_marginal_model <- function(data, ...) { assert_data_frame(data) assert_names(names(data), must.include = c( "pwindow", "swindow", "delay_lwr", "delay_upr", "n", - ".row_id", "relative_obs_time" + "relative_obs_time" )) assert_numeric(data$pwindow, lower = 0) assert_numeric(data$swindow, lower = 0) assert_integerish(data$delay_lwr) assert_integerish(data$delay_upr) assert_numeric(data$relative_obs_time) - assert_integerish(data$.row_id, lower = 1) if (!all(abs(data$delay_upr - (data$delay_lwr + data$swindow)) < 1e-10)) { cli::cli_abort( "delay_upr must equal delay_lwr + swindow" From f42a182913ef99025195179282b713df5b40c59a Mon Sep 17 00:00:00 2001 From: Sam Date: Mon, 2 Dec 2024 17:27:09 +0000 Subject: [PATCH 46/62] check using ... properly --- R/latent_model.R | 6 ++++-- R/marginal_model.R | 3 ++- man/as_epidist_latent_model.Rd | 4 +++- man/as_epidist_latent_model.epidist_linelist_data.Rd | 4 +++- man/as_epidist_marginal_model.epidist_linelist_data.Rd | 4 +++- 5 files changed, 15 insertions(+), 6 deletions(-) diff --git a/R/latent_model.R b/R/latent_model.R index 119e2ae67..de21666aa 100644 --- a/R/latent_model.R +++ b/R/latent_model.R @@ -1,9 +1,10 @@ #' Convert an object to an `epidist_latent_model` object #' #' @param data An object to be converted to the class `epidist_latent_model` +#' @param ... Additional arguments passed to methods. #' @family latent_model #' @export -as_epidist_latent_model <- function(data) { +as_epidist_latent_model <- function(data, ...) { UseMethod("as_epidist_latent_model") } @@ -11,11 +12,12 @@ as_epidist_latent_model <- function(data) { #' The latent model method for `epidist_linelist_data` objects #' #' @param data An `epidist_linelist_data` object +#' @param ... Not used in this method. #' @method as_epidist_latent_model epidist_linelist_data #' @family latent_model #' @autoglobal #' @export -as_epidist_latent_model.epidist_linelist_data <- function(data) { +as_epidist_latent_model.epidist_linelist_data <- function(data, ...) { assert_epidist(data) data <- data |> mutate( diff --git a/R/marginal_model.R b/R/marginal_model.R index dd4c136d3..ee8a20d9b 100644 --- a/R/marginal_model.R +++ b/R/marginal_model.R @@ -15,12 +15,13 @@ as_epidist_marginal_model <- function(data, ...) { #' relative observation times to Inf. Observation times greater than #' `obs_time_threshold` times the maximum delay will be set to Inf to improve #' model efficiency. Default is 2. +#' @param ... Not used in this method. #' @method as_epidist_marginal_model epidist_linelist_data #' @family marginal_model #' @autoglobal #' @export as_epidist_marginal_model.epidist_linelist_data <- function( - data, obs_time_threshold = 2) { + data, obs_time_threshold = 2, ...) { assert_epidist(data) data <- data |> diff --git a/man/as_epidist_latent_model.Rd b/man/as_epidist_latent_model.Rd index 9f9a419aa..f63e925c3 100644 --- a/man/as_epidist_latent_model.Rd +++ b/man/as_epidist_latent_model.Rd @@ -4,10 +4,12 @@ \alias{as_epidist_latent_model} \title{Convert an object to an \code{epidist_latent_model} object} \usage{ -as_epidist_latent_model(data) +as_epidist_latent_model(data, ...) } \arguments{ \item{data}{An object to be converted to the class \code{epidist_latent_model}} + +\item{...}{Additional arguments passed to methods.} } \description{ Convert an object to an \code{epidist_latent_model} object diff --git a/man/as_epidist_latent_model.epidist_linelist_data.Rd b/man/as_epidist_latent_model.epidist_linelist_data.Rd index 3b91956df..5e8e86af8 100644 --- a/man/as_epidist_latent_model.epidist_linelist_data.Rd +++ b/man/as_epidist_latent_model.epidist_linelist_data.Rd @@ -4,10 +4,12 @@ \alias{as_epidist_latent_model.epidist_linelist_data} \title{The latent model method for \code{epidist_linelist_data} objects} \usage{ -\method{as_epidist_latent_model}{epidist_linelist_data}(data) +\method{as_epidist_latent_model}{epidist_linelist_data}(data, ...) } \arguments{ \item{data}{An \code{epidist_linelist_data} object} + +\item{...}{Not used in this method.} } \description{ The latent model method for \code{epidist_linelist_data} objects diff --git a/man/as_epidist_marginal_model.epidist_linelist_data.Rd b/man/as_epidist_marginal_model.epidist_linelist_data.Rd index 2754f4058..3eac380c6 100644 --- a/man/as_epidist_marginal_model.epidist_linelist_data.Rd +++ b/man/as_epidist_marginal_model.epidist_linelist_data.Rd @@ -4,7 +4,7 @@ \alias{as_epidist_marginal_model.epidist_linelist_data} \title{The marginal model method for \code{epidist_linelist_data} objects} \usage{ -\method{as_epidist_marginal_model}{epidist_linelist_data}(data, obs_time_threshold = 2) +\method{as_epidist_marginal_model}{epidist_linelist_data}(data, obs_time_threshold = 2, ...) } \arguments{ \item{data}{An \code{epidist_linelist_data} object} @@ -13,6 +13,8 @@ relative observation times to Inf. Observation times greater than \code{obs_time_threshold} times the maximum delay will be set to Inf to improve model efficiency. Default is 2.} + +\item{...}{Not used in this method.} } \description{ The marginal model method for \code{epidist_linelist_data} objects From dfc73e013b2ca606a77024f395da1e6d4a95effa Mon Sep 17 00:00:00 2001 From: Sam Date: Mon, 2 Dec 2024 17:47:45 +0000 Subject: [PATCH 47/62] make the progress messages prettier for reducing data complexit: --- R/marginal_model.R | 19 +++++++++++-------- ...st_marginal_model.epidist_linelist_data.Rd | 3 ++- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/R/marginal_model.R b/R/marginal_model.R index ee8a20d9b..9d3722c20 100644 --- a/R/marginal_model.R +++ b/R/marginal_model.R @@ -14,7 +14,8 @@ as_epidist_marginal_model <- function(data, ...) { #' @param obs_time_threshold Ratio used to determine threshold for setting #' relative observation times to Inf. Observation times greater than #' `obs_time_threshold` times the maximum delay will be set to Inf to improve -#' model efficiency. Default is 2. +#' model efficiency by reducing the number of unique observation times. +#' Default is 2. #' @param ... Not used in this method. #' @method as_epidist_marginal_model epidist_linelist_data #' @family marginal_model @@ -167,13 +168,15 @@ epidist_transform_data_model.epidist_marginal_model <- function( new_epidist_marginal_model() n_rows_after <- nrow(trans_data) if (n_rows_before > n_rows_after) { - cli::cli_inform("Data summarised by unique combinations of:") + cli::cli_inform(c( + "i" = "Data summarised by unique combinations of:" # nolint + )) formula_vars <- setdiff(names(trans_data), c(required_cols, "n")) if (length(formula_vars) > 0) { - cli::cli_inform( - paste0("* Formula variables: {.code {formula_vars}}") - ) + cli::cli_inform(c( + "*" = "Formula variables: {.code {formula_vars}}" + )) } cli::cli_inform(paste0( @@ -181,9 +184,9 @@ epidist_transform_data_model.epidist_marginal_model <- function( "and primary censoring window" )) - cli::cli_inform(paste0( - "i" = "Reduced from {n_rows_before} to {n_rows_after} rows. ", # nolint - "This should improve model efficiency with no loss of information." + cli::cli_inform(c( + "!" = paste("Reduced from", n_rows_before, "to", n_rows_after, "rows."), + "i" = "This should improve model efficiency with no loss of information." # nolint )) } diff --git a/man/as_epidist_marginal_model.epidist_linelist_data.Rd b/man/as_epidist_marginal_model.epidist_linelist_data.Rd index 3eac380c6..2277bbe09 100644 --- a/man/as_epidist_marginal_model.epidist_linelist_data.Rd +++ b/man/as_epidist_marginal_model.epidist_linelist_data.Rd @@ -12,7 +12,8 @@ \item{obs_time_threshold}{Ratio used to determine threshold for setting relative observation times to Inf. Observation times greater than \code{obs_time_threshold} times the maximum delay will be set to Inf to improve -model efficiency. Default is 2.} +model efficiency by reducing the number of unique observation times. +Default is 2.} \item{...}{Not used in this method.} } From 6c529f1c6ef639376830fbcd285ec94e707f096c Mon Sep 17 00:00:00 2001 From: Sam Date: Mon, 2 Dec 2024 18:21:42 +0000 Subject: [PATCH 48/62] check post process tests again --- tests/testthat/test-postprocess.R | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/testthat/test-postprocess.R b/tests/testthat/test-postprocess.R index 5d898ad9b..3ccfe1d19 100644 --- a/tests/testthat/test-postprocess.R +++ b/tests/testthat/test-postprocess.R @@ -28,7 +28,8 @@ test_that("predict_delay_parameters accepts newdata arguments and prediction by # Helper function to test sex predictions test_sex_predictions <- function(fit, prep = prep_obs_sex) { set.seed(1) - + prep <- prep |> + dplyr::mutate(.row_id = dplyr::row_number()) pred_sex <- predict_delay_parameters(fit, prep) expect_s3_class(pred_sex, "lognormal_samples") expect_s3_class(pred_sex, "data.frame") From 9c0d9a293883b0e6154438b3794c4933ffa93198 Mon Sep 17 00:00:00 2001 From: Sam Date: Mon, 2 Dec 2024 18:34:21 +0000 Subject: [PATCH 49/62] add a test for the specific transform data method --- tests/testthat/test-marginal_model.R | 69 ++++++++++++++++++++++++++-- 1 file changed, 64 insertions(+), 5 deletions(-) diff --git a/tests/testthat/test-marginal_model.R b/tests/testthat/test-marginal_model.R index 34085284b..aeaf43c30 100644 --- a/tests/testthat/test-marginal_model.R +++ b/tests/testthat/test-marginal_model.R @@ -4,16 +4,13 @@ test_that("as_epidist_marginal_model.epidist_linelist_data with default settings expect_s3_class(prep_marginal_obs, "epidist_marginal_model") }) -test_that("as_epidist_marginal_model.epidist_linelist_data when passed incorrect inputs", { # nolint: line_length_linter. +test_that("as_epidist_marginal_model.epidist_linelist_data errors when passed incorrect inputs", { # nolint: line_length_linter. expect_error(as_epidist_marginal_model(list())) expect_error(as_epidist_marginal_model(sim_obs[, 1])) }) # Make this data available for other tests -family_lognormal <- epidist_family( - prep_marginal_obs, - family = brms::lognormal() -) +family_lognormal <- epidist_family(prep_marginal_obs, family = lognormal()) test_that("is_epidist_marginal_model returns TRUE for correct input", { # nolint: line_length_linter. expect_true(is_epidist_marginal_model(prep_marginal_obs)) @@ -46,3 +43,65 @@ test_that("assert_epidist.epidist_marginal_model returns FALSE for incorrect inp assert_epidist(x) }) }) + +test_that("epidist_stancode.epidist_marginal_model produces valid stanvars", { # nolint: line_length_linter. + epidist_family <- epidist_family(prep_marginal_obs) + epidist_formula <- epidist_formula( + prep_marginal_obs, epidist_family, + formula = bf(mu ~ 1) + ) + stancode <- epidist_stancode( + prep_marginal_obs, + family = epidist_family, formula = epidist_formula + ) + expect_s3_class(stancode, "stanvars") +}) + +test_that("epidist_transform_data_model.epidist_marginal_model correctly transforms data and messages", { # nolint: line_length_linter. + family <- epidist_family(prep_marginal_obs, family = lognormal()) + formula <- epidist_formula( + prep_marginal_obs, + formula = bf(mu ~ 1), + family = family + ) + expect_no_message( + expect_message( + expect_message( + expect_message( + epidist_transform_data_model( + prep_marginal_obs, + family = family, + formula = formula + ), + "Reduced from 500 to 144 rows." + ), + "Data summarised by unique combinations of:" + ), + "Model variables" + ) + ) + + family <- epidist_family(prep_marginal_obs, family = lognormal()) + formula <- epidist_formula( + prep_marginal_obs, + formula = bf(mu ~ 1 + ptime_lwr), + family = family + ) + expect_message( + expect_message( + expect_message( + expect_message( + epidist_transform_data_model( + prep_marginal_obs, + family = family, + formula = formula + ), + "Reduced from 500 to 144 rows." + ), + "Data summarised by unique combinations of:" + ), + "Model variables" + ), + "ptime_lwr" + ) +}) From 0712992c80d5cb24a38efa3d9ecde50bf385db67 Mon Sep 17 00:00:00 2001 From: Sam Date: Mon, 2 Dec 2024 18:38:25 +0000 Subject: [PATCH 50/62] add some tests for the generic transform data method --- tests/test-transform_data.R | 39 +++++++++++++++++++++++++++++++++++++ 1 file changed, 39 insertions(+) create mode 100644 tests/test-transform_data.R diff --git a/tests/test-transform_data.R b/tests/test-transform_data.R new file mode 100644 index 000000000..a970d8652 --- /dev/null +++ b/tests/test-transform_data.R @@ -0,0 +1,39 @@ +test_that( + "epidist_transform_data with default settings returns data unchanged", + { + family <- epidist_family(prep_obs, family = lognormal()) + formula <- epidist_formula(prep_obs, family = family, formula = bf(mu ~ 1)) + + transformed <- epidist_transform_data(prep_obs, family, formula) + expect_identical(transformed, prep_obs) + } +) + +test_that("epidist_transform_data errors when passed incorrect inputs", { + family <- epidist_family(prep_obs, family = lognormal()) + formula <- epidist_formula(prep_obs, family = family, formula = bf(mu ~ 1)) + + expect_error(epidist_transform_data(list(), family, formula)) +}) + +test_that("epidist_transform_data_model.default returns data unchanged", { + family <- epidist_family(prep_obs, family = lognormal()) + formula <- epidist_formula(prep_obs, family = family, formula = bf(mu ~ 1)) + + transformed <- epidist_transform_data_model(prep_obs, family, formula) + expect_identical(transformed, prep_obs) +}) + +test_that("epidist_transform_data works with different model types", { + family <- epidist_family(prep_obs, family = lognormal()) + formula <- epidist_formula(prep_obs, family = family, formula = bf(mu ~ 1)) + + expect_identical( + epidist_transform_data(prep_naive_obs, family, formula), + prep_naive_obs + ) + expect_identical( + epidist_transform_data(prep_obs_gamma, family, formula), + prep_obs_gamma + ) +}) From 2aabfd7ddf9055895862d80d9fcbaa64431dd60a Mon Sep 17 00:00:00 2001 From: Sam Date: Mon, 2 Dec 2024 20:48:56 +0000 Subject: [PATCH 51/62] put transform data tests in the correct folder --- tests/{ => testthat}/test-transform_data.R | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/{ => testthat}/test-transform_data.R (100%) diff --git a/tests/test-transform_data.R b/tests/testthat/test-transform_data.R similarity index 100% rename from tests/test-transform_data.R rename to tests/testthat/test-transform_data.R From ff89e3b57354ce6515abf1cba6d11e9aeba06746 Mon Sep 17 00:00:00 2001 From: Sam Date: Tue, 3 Dec 2024 10:14:13 +0000 Subject: [PATCH 52/62] add a news update --- NEWS.md | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/NEWS.md b/NEWS.md index 63d0ca7c8..4355e83d5 100644 --- a/NEWS.md +++ b/NEWS.md @@ -2,19 +2,26 @@ Development version of `epidist`. +## Models + +- Added a marginalised likelihood model based on `primarycensored`. This can be specified using `as_epidist_marginal_model()`. This is currently limited to Weibull, Log-Normal, and Gamma distributions with uniform primary censoring but this will be generalised in future releases. See #426. +- Added user settable primary event priors to the latent model. See #474. +- Added a marginalised likelihood to the latent model. See #474. + ## Package - Remove the default method for `epidist()`. See #473. - Added `enforce_presence` argument to `epidist_prior()` to allow for priors to be specified if they do not match existing parameters. See #474. - Added a `merge` argument to `epidist_prior()` to allow for not merging user and package priors. See #474. -- Added user settable primary event priors to the latent model. See #474. -- Added a marginalised likelihood to the latent model. See #474. - Generalised the stan reparametrisation feature to work across all distributions without manual specification by generating stan code with `brms` and then extracting the reparameterisation. See #474. +- Added a `transform_data` s3 method to allow for data to be transformed for specific models. This is specifically useful for the marginal model at the moment as it allows reducing the data to its unique strata. See #474. ## Documentation - Brings the README into line with `epinowcast` standards. See #467. +- Switched over to using the marginal model everywhere. See #426. +- Added helper functions for new variables to avoid code duplication in vignettes. See #426. # epidist 0.1.0 From 3a3aed272b158c8b09eec8920ed5bd5084e353e8 Mon Sep 17 00:00:00 2001 From: Sam Date: Tue, 3 Dec 2024 10:44:04 +0000 Subject: [PATCH 53/62] change vignette language to talk about marginal model --- .Rbuildignore | 1 + vignettes/approx-inference.Rmd | 4 ++-- vignettes/ebola.Rmd | 4 ++-- vignettes/faq.Rmd | 5 +++-- 4 files changed, 8 insertions(+), 6 deletions(-) diff --git a/.Rbuildignore b/.Rbuildignore index be0e79bdb..151bca3d6 100644 --- a/.Rbuildignore +++ b/.Rbuildignore @@ -17,4 +17,5 @@ ^pkgdown$ ^vignettes/approx-inference\.Rmd$ ^vignettes/ebola\.Rmd$ +^vignettes/faq\.Rmd$ ^\.lintr$ diff --git a/vignettes/approx-inference.Rmd b/vignettes/approx-inference.Rmd index bb91d8eee..526ca64d2 100644 --- a/vignettes/approx-inference.Rmd +++ b/vignettes/approx-inference.Rmd @@ -135,7 +135,7 @@ obs_cens_trunc_samp <- simulate_gillespie(seed = 101) |> slice_sample(n = sample_size, replace = FALSE) ``` -We now prepare the data for fitting with the latent individual model, and perform inference with HMC: +We now prepare the data for fitting with the marginal model, and perform inference with HMC: ```{r results='hide'} linelist_data <- as_epidist_linelist_data( @@ -155,7 +155,7 @@ time_hmc <- proc.time() - t Note that for clarity above we specify `algorithm = "sampling"`, but if you were to call `epidist(data = data)` the result would be the same since `"sampling"` (i.e. HMC) is the default value for the `algorithm` argument. -Now, we fit^[Note that in this section, and above for the MCMC, the output of the call is hidden, but if you were to call these functions yourself they would display information about the fitting procedure as it occurs] the same latent individual model using each method in Section \@ref(other). +Now, we fit^[Note that in this section, and above for the MCMC, the output of the call is hidden, but if you were to call these functions yourself they would display information about the fitting procedure as it occurs] the same marginal model using each method in Section \@ref(other). To match the four Markov chains of length 1000 in HMC above, we then draw 4000 samples from each approximate posterior. ```{r results='hide'} diff --git a/vignettes/ebola.Rmd b/vignettes/ebola.Rmd index c0cded2f6..9e1b176c5 100644 --- a/vignettes/ebola.Rmd +++ b/vignettes/ebola.Rmd @@ -216,14 +216,14 @@ Second, because we also did not supply an observation time column (`obs_date`), ## Model fitting -To prepare the data for use with the latent individual model, we define the data as being a `epidist_latent_model` model object: +To prepare the data for use with the marginal model, we define the data as being a `epidist_marginal_model` model object: ```{r} obs_prep <- as_epidist_marginal_model(linelist_data) head(obs_prep) ``` -Now we are ready to fit the latent individual model. +Now we are ready to fit the marginal model. ### Intercept-only model diff --git a/vignettes/faq.Rmd b/vignettes/faq.Rmd index 3343eaf24..082a036c8 100644 --- a/vignettes/faq.Rmd +++ b/vignettes/faq.Rmd @@ -61,12 +61,13 @@ linelist_data <- as_epidist_linelist_data( obs_cens_trunc_samp$stime_upr, obs_time = obs_cens_trunc_samp$obs_time ) -data <- as_epidist_latent_model(linelist_data) +data <- as_epidist_marginal_model(linelist_data) fit <- epidist( data, formula = mu ~ 1, - seed = 1 + seed = 1, + backend = "cmdstanr" ) ``` From 2d1aef47fbe79e9f4694c9abca2a6acecbf16582 Mon Sep 17 00:00:00 2001 From: Sam Date: Tue, 3 Dec 2024 11:00:10 +0000 Subject: [PATCH 54/62] update the FAQ to use the marginal variables --- vignettes/faq.Rmd | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/vignettes/faq.Rmd b/vignettes/faq.Rmd index 082a036c8..a24aef918 100644 --- a/vignettes/faq.Rmd +++ b/vignettes/faq.Rmd @@ -67,6 +67,9 @@ fit <- epidist( data, formula = mu ~ 1, seed = 1, + chains = 2, + cores = 2, + refresh = ifelse(interactive(), 250, 0), backend = "cmdstanr" ) ``` @@ -152,7 +155,8 @@ fit_ppc <- epidist( formula = mu ~ 1, family = lognormal(), sample_prior = "only", - seed = 1 + seed = 1, + backend = "cmdstanr" ) ``` @@ -218,8 +222,10 @@ To see these functions demonstrated in a vignette, see ["Advanced features with As a short example, to generate 4000 predictions (equal to the number of draws) of the delay that would be observed with a double censored observation process (in which the primary and secondary censoring windows are both one) then: ```{r} -draws_pmf <- data.frame(relative_obs_time = 1000, pwindow = 1, swindow = 1) |> - add_predicted_draws(fit, ndraws = 4000) +draws_pmf <- tibble::tibble( + relative_obs_time = Inf, pwindow = 1, swindow = 1, delay_upr = NA +) |> + add_predicted_draws(fit, ndraws = 2000) ggplot(draws_pmf, aes(x = .prediction)) + geom_bar(aes(y = after_stat(count / sum(count)))) + @@ -228,10 +234,7 @@ ggplot(draws_pmf, aes(x = .prediction)) + theme_minimal() ``` -Importantly, this functionality is only available for `epidist` models using custom `brms` families that have `posterior_predict` and `posterior_epred` methods implemented. -For example, for the `epidist_latent_model` model, currently methods are implemented for the [lognormal](https://github.com/epinowcast/epidist/blob/main/R/latent_lognormal.R) and [gamma](https://github.com/epinowcast/epidist/blob/main/R/latent_gamma.R) families. -If you are using another family, consider [submitting a pull request](https://github.com/epinowcast/epidist/pulls) to implement these methods! -In doing so, you may find it useful to use the [`primarycensored`](https://primarycensored.epinowcast.org/) package. +Importantly, this functionality is only available for `epidist` models using `brms` families that have a `log_lik_censor` method implemented internally in `brms`. If you are using another family, consider [submitting a pull request](https://github.com/epinowcast/epidist/pulls) to implement these methods! # How can I use the `cmdstanr` backend? From 2a6118214ac8168f1776c202a5b9729e893f30bb Mon Sep 17 00:00:00 2001 From: Sam Date: Tue, 3 Dec 2024 11:04:28 +0000 Subject: [PATCH 55/62] call it transformed_data not trans_data --- R/epidist.R | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/R/epidist.R b/R/epidist.R index 8825edea9..a4718c229 100644 --- a/R/epidist.R +++ b/R/epidist.R @@ -38,18 +38,20 @@ epidist <- function(data, formula = mu ~ 1, epidist_formula <- epidist_formula( data = data, family = epidist_family, formula = formula ) - trans_data <- epidist_transform_data(data, epidist_family, epidist_formula) + transformed_data <- epidist_transform_data( + data, epidist_family, epidist_formula + ) epidist_prior <- epidist_prior( - data = trans_data, family = epidist_family, + data = transformed_data, family = epidist_family, formula = epidist_formula, prior, merge = merge_priors ) epidist_stancode <- epidist_stancode( - data = trans_data, family = epidist_family, formula = epidist_formula + data = transformed_data, family = epidist_family, formula = epidist_formula ) fit <- fn( formula = epidist_formula, family = epidist_family, prior = epidist_prior, - stanvars = epidist_stancode, data = trans_data, ... + stanvars = epidist_stancode, data = transformed_data, ... ) class(fit) <- c(class(fit), "epidist_fit") return(fit) From 4f080c4adb9326834710e18726e1e417c1eb861a Mon Sep 17 00:00:00 2001 From: Sam Date: Tue, 3 Dec 2024 11:23:29 +0000 Subject: [PATCH 56/62] change the error message to make it clear its a epidist limitation --- R/marginal_model.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/R/marginal_model.R b/R/marginal_model.R index 9d3722c20..21837d506 100644 --- a/R/marginal_model.R +++ b/R/marginal_model.R @@ -226,7 +226,7 @@ epidist_stancode.epidist_marginal_model <- function( dist_id <- 3 } else { cli_abort(c( - "!" = "No solution available in primarycensored for this family" + "!" = "epidist does not currently support this family for the marginal model" # nolint )) } From 6293762919b86f5c5c7de2439865eb550a71fdf3 Mon Sep 17 00:00:00 2001 From: Sam Date: Tue, 3 Dec 2024 11:27:23 +0000 Subject: [PATCH 57/62] update stan docs --- inst/stan/latent_model/functions.stan | 8 ++++---- inst/stan/marginal_model/functions.stan | 9 ++++----- 2 files changed, 8 insertions(+), 9 deletions(-) diff --git a/inst/stan/latent_model/functions.stan b/inst/stan/latent_model/functions.stan index 3134f7f95..eac1bd921 100644 --- a/inst/stan/latent_model/functions.stan +++ b/inst/stan/latent_model/functions.stan @@ -3,11 +3,11 @@ * * This function is designed to be read into R where: * - 'family' is replaced with the target distribution (e.g., 'lognormal') - * - 'dpars_A' is replaced with multiple parameters in the format + * - 'dpars_A' is replaced with multiple distribution parameters in the format * "vector|real paramname1, vector|real paramname2, ..." depending on whether - * each parameter has a model. This includes distribution parameters. - * - 'dpars_B' is replaced with the same parameters as dpars_A but with window - * indices removed. + * each parameter has a model. + * - 'dpars_B' is replaced with the same parameters as dpars_A but + * reparameterised according to the brms parameterisation for Stan. * * @param y Vector of observed values (delays) * @param dpars_A Distribution parameters (replaced via regex) diff --git a/inst/stan/marginal_model/functions.stan b/inst/stan/marginal_model/functions.stan index 3c1f05cf1..ee55dae2c 100644 --- a/inst/stan/marginal_model/functions.stan +++ b/inst/stan/marginal_model/functions.stan @@ -3,11 +3,10 @@ * * This function is designed to be read into R where: * - 'family' is replaced with the target distribution (e.g., 'lognormal') - * - 'dpars_A' is replaced with multiple parameters in the format - * "vector|real paramname1, vector|real paramname2, ..." depending on whether - * each parameter has a model. This includes distribution parameters. - * - 'dpars_B' is replaced with the same parameters as dpars_A but with window - * indices removed. + * - 'dpars_A' is replaced with multiple distribution parameters in the format + * "real paramname1, real paramname2, ...". + * - 'dpars_B' is replaced with the same parameters as dpars_A but + * reparameterised according to the brms parameterisation for Stan. * * @param y Real value of observed delay * @param dpars_A Distribution parameters (replaced via regex) From 7062e52f72dc1fab0ca4f9f6c4a0919aed313a44 Mon Sep 17 00:00:00 2001 From: Sam Abbott Date: Tue, 3 Dec 2024 13:15:09 +0000 Subject: [PATCH 58/62] Update NEWS.md --- NEWS.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/NEWS.md b/NEWS.md index 4355e83d5..0e88ba98c 100644 --- a/NEWS.md +++ b/NEWS.md @@ -4,7 +4,7 @@ Development version of `epidist`. ## Models -- Added a marginalised likelihood model based on `primarycensored`. This can be specified using `as_epidist_marginal_model()`. This is currently limited to Weibull, Log-Normal, and Gamma distributions with uniform primary censoring but this will be generalised in future releases. See #426. +- Added a marginalised likelihood model based on `primarycensored`. This can be specified using `as_epidist_marginal_model()`. This is currently limited to Weibull, log-normal, and gamma distributions with uniform primary censoring but this will be generalised in future releases. See #426. - Added user settable primary event priors to the latent model. See #474. - Added a marginalised likelihood to the latent model. See #474. From 79f0a881eb9680bf50804f30663a74796a2f8ce0 Mon Sep 17 00:00:00 2001 From: Sam Abbott Date: Tue, 3 Dec 2024 13:15:19 +0000 Subject: [PATCH 59/62] Update NEWS.md Co-authored-by: Adam Howes --- NEWS.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/NEWS.md b/NEWS.md index 0e88ba98c..6aa77a108 100644 --- a/NEWS.md +++ b/NEWS.md @@ -15,7 +15,7 @@ Development version of `epidist`. specified if they do not match existing parameters. See #474. - Added a `merge` argument to `epidist_prior()` to allow for not merging user and package priors. See #474. - Generalised the stan reparametrisation feature to work across all distributions without manual specification by generating stan code with `brms` and then extracting the reparameterisation. See #474. -- Added a `transform_data` s3 method to allow for data to be transformed for specific models. This is specifically useful for the marginal model at the moment as it allows reducing the data to its unique strata. See #474. +- Added a `transform_data` S3 method to allow for data to be transformed for specific models. This is specifically useful for the marginal model at the moment as it allows reducing the data to its unique strata. See #474. ## Documentation From 292cfb380726e09bd36e457e5c7c84d6ffb0a31b Mon Sep 17 00:00:00 2001 From: Sam Abbott Date: Tue, 3 Dec 2024 13:15:31 +0000 Subject: [PATCH 60/62] Update NEWS.md Co-authored-by: Adam Howes --- NEWS.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/NEWS.md b/NEWS.md index 6aa77a108..14395a9ad 100644 --- a/NEWS.md +++ b/NEWS.md @@ -20,7 +20,7 @@ Development version of `epidist`. ## Documentation - Brings the README into line with `epinowcast` standards. See #467. -- Switched over to using the marginal model everywhere. See #426. +- Switched over to using the marginal model as default in documentation. See #426. - Added helper functions for new variables to avoid code duplication in vignettes. See #426. # epidist 0.1.0 From ef17a1af79dd0cf3c8d2fee92f32c96aef931af7 Mon Sep 17 00:00:00 2001 From: Sam Abbott Date: Tue, 3 Dec 2024 13:18:23 +0000 Subject: [PATCH 61/62] Update inst/stan/latent_model/functions.stan Co-authored-by: Adam Howes --- inst/stan/latent_model/functions.stan | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/inst/stan/latent_model/functions.stan b/inst/stan/latent_model/functions.stan index eac1bd921..fba796829 100644 --- a/inst/stan/latent_model/functions.stan +++ b/inst/stan/latent_model/functions.stan @@ -3,7 +3,7 @@ * * This function is designed to be read into R where: * - 'family' is replaced with the target distribution (e.g., 'lognormal') - * - 'dpars_A' is replaced with multiple distribution parameters in the format + * - 'dpars_A' is replaced with multiple distribution parameters in the format * "vector|real paramname1, vector|real paramname2, ..." depending on whether * each parameter has a model. * - 'dpars_B' is replaced with the same parameters as dpars_A but From 0ce7617cec1110b206870b40ba95da94221ee0f4 Mon Sep 17 00:00:00 2001 From: Sam Abbott Date: Tue, 3 Dec 2024 13:19:07 +0000 Subject: [PATCH 62/62] Update setup.R --- tests/testthat/setup.R | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/testthat/setup.R b/tests/testthat/setup.R index f9d76a014..d6d056826 100644 --- a/tests/testthat/setup.R +++ b/tests/testthat/setup.R @@ -160,7 +160,7 @@ if (not_on_cran()) { fit_marginal_sex <- suppressMessages(epidist( data = prep_marginal_obs_sex, formula = bf(mu ~ 1 + sex, sigma ~ 1 + sex), - seed = 1, silent = 2, refresh = 50, + seed = 1, silent = 2, refresh = 0, cores = 2, chains = 2, backend = "cmdstanr" )) }