diff --git a/R/marginal_model.R b/R/marginal_model.R index 7e389c9ab..4a80e6b52 100644 --- a/R/marginal_model.R +++ b/R/marginal_model.R @@ -80,7 +80,7 @@ is_epidist_marginal_model <- function(data) { epidist_family_model.epidist_marginal_model <- function( data, family, ...) { custom_family <- brms::custom_family( - "primarycensored_wrapper", + paste0("marginal_", family$family), dpars = family$dpars, links = c(family$link, family$other_links), lb = c(NA, as.numeric(lapply(family$other_bounds, "[[", "lb"))), @@ -113,7 +113,10 @@ epidist_formula_model.epidist_marginal_model <- function( #' @family marginal_model #' @autoglobal #' @export -epidist_stancode.epidist_marginal_model <- function(data, ...) { +epidist_stancode.epidist_marginal_model <- function( + data, + family = epidist_family(data), + formula = epidist_formula(data), ...) { assert_epidist(data) stanvars_version <- .version_stanvar() @@ -123,6 +126,58 @@ epidist_stancode.epidist_marginal_model <- function(data, ...) { scode = .stan_chunk(file.path("marginal_model", "functions.stan")) ) + family_name <- gsub("marginal_", "", family$name, fixed = TRUE) + + stanvars_functions[[1]]$scode <- gsub( + "family", family_name, stanvars_functions[[1]]$scode, + fixed = TRUE + ) + + # Can probably be extended to non-analytic solution families but for now + if (family_name == "lognormal") { + dist_id <- 1 + } else if (family_name == "gamma") { + dist_id <- 2 + } else if (family_name == "weibell") { + dist_id <- 3 + } else { + cli_abort(c( + "!" = "No analytic solution available in primarycensored for this family" + )) + } + + # Replace the dist_id passed to primarycensored + stanvars_functions[[1]]$scode <- gsub( + "input_dist_id", dist_id, stanvars_functions[[1]]$scode, + fixed = TRUE + ) + + # Inject vector or real depending if there is a model for each dpar + vector_real <- purrr::map_vec(family$dpars, function(dpar) { + return("real") + }) + + stanvars_functions[[1]]$scode <- gsub( + "dpars_A", + toString(paste0(vector_real, " ", family$dpars)), + stanvars_functions[[1]]$scode, + fixed = TRUE + ) + + # Need to consider whether any reparametrisation is required here for input + # input primarycensored. Assume not for now. Also assume two dpars + stanvars_functions[[1]]$scode <- gsub( + "dpars_1", family$dpars[1], + stanvars_functions[[1]]$scode, + fixed = TRUE + ) + + stanvars_functions[[1]]$scode <- gsub( + "dpars_2", family$dpars[2], + stanvars_functions[[1]]$scode, + fixed = TRUE + ) + pcd_stanvars_functions <- brms::stanvar( block = "functions", scode = primarycensored::pcd_load_stan_functions() diff --git a/inst/stan/marginal_model/functions.stan b/inst/stan/marginal_model/functions.stan index fc7e9ee92..5df25b5b9 100644 --- a/inst/stan/marginal_model/functions.stan +++ b/inst/stan/marginal_model/functions.stan @@ -1,10 +1,18 @@ -real primarycensored_wrapper_lpmf(data int d, real mu, real sigma, data real pwindow) { - int dist_id = 1; // lognormal +// This function is a wrapper to primarycensored_lpmf +// Here the strings +// * family +// * dpars_A +// * dpars_1 +// * dpars_2 +// are/have been replaced using regex + +real marginal_family_lpmf(data int d, dpars_A, data real pwindow) { + int dist_id = input_dist_id; array[2] real params; - params[1] = mu; - params[2] = sigma; + params[1] = dpars_1; + params[2] = dpars_2; int d_upper = d + 1; - int primary_id = 1; // Uniform + int primary_id = 1; // Fixed as uniform array[0] real primary_params; return primarycensored_lpmf(d | dist_id, params, pwindow, d_upper, positive_infinity(), primary_id, primary_params); }