Skip to content

Commit

Permalink
Regex version of marginal model
Browse files Browse the repository at this point in the history
  • Loading branch information
athowes committed Nov 25, 2024
1 parent 26e7b4b commit 0d71938
Show file tree
Hide file tree
Showing 2 changed files with 70 additions and 7 deletions.
59 changes: 57 additions & 2 deletions R/marginal_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ is_epidist_marginal_model <- function(data) {
epidist_family_model.epidist_marginal_model <- function(
data, family, ...) {
custom_family <- brms::custom_family(
"primarycensored_wrapper",
paste0("marginal_", family$family),
dpars = family$dpars,
links = c(family$link, family$other_links),
lb = c(NA, as.numeric(lapply(family$other_bounds, "[[", "lb"))),
Expand Down Expand Up @@ -113,7 +113,10 @@ epidist_formula_model.epidist_marginal_model <- function(
#' @family marginal_model
#' @autoglobal
#' @export
epidist_stancode.epidist_marginal_model <- function(data, ...) {
epidist_stancode.epidist_marginal_model <- function(
data,
family = epidist_family(data),
formula = epidist_formula(data), ...) {
assert_epidist(data)

stanvars_version <- .version_stanvar()
Expand All @@ -123,6 +126,58 @@ epidist_stancode.epidist_marginal_model <- function(data, ...) {
scode = .stan_chunk(file.path("marginal_model", "functions.stan"))
)

family_name <- gsub("marginal_", "", family$name, fixed = TRUE)

stanvars_functions[[1]]$scode <- gsub(
"family", family_name, stanvars_functions[[1]]$scode,
fixed = TRUE
)

# Can probably be extended to non-analytic solution families but for now
if (family_name == "lognormal") {
dist_id <- 1
} else if (family_name == "gamma") {
dist_id <- 2
} else if (family_name == "weibell") {
dist_id <- 3
} else {
cli_abort(c(
"!" = "No analytic solution available in primarycensored for this family"
))
}

# Replace the dist_id passed to primarycensored
stanvars_functions[[1]]$scode <- gsub(
"input_dist_id", dist_id, stanvars_functions[[1]]$scode,
fixed = TRUE
)

# Inject vector or real depending if there is a model for each dpar
vector_real <- purrr::map_vec(family$dpars, function(dpar) {
return("real")
})

stanvars_functions[[1]]$scode <- gsub(
"dpars_A",
toString(paste0(vector_real, " ", family$dpars)),
stanvars_functions[[1]]$scode,
fixed = TRUE
)

# Need to consider whether any reparametrisation is required here for input
# input primarycensored. Assume not for now. Also assume two dpars
stanvars_functions[[1]]$scode <- gsub(
"dpars_1", family$dpars[1],
stanvars_functions[[1]]$scode,
fixed = TRUE
)

stanvars_functions[[1]]$scode <- gsub(
"dpars_2", family$dpars[2],
stanvars_functions[[1]]$scode,
fixed = TRUE
)

pcd_stanvars_functions <- brms::stanvar(
block = "functions",
scode = primarycensored::pcd_load_stan_functions()
Expand Down
18 changes: 13 additions & 5 deletions inst/stan/marginal_model/functions.stan
Original file line number Diff line number Diff line change
@@ -1,10 +1,18 @@
real primarycensored_wrapper_lpmf(data int d, real mu, real sigma, data real pwindow) {
int dist_id = 1; // lognormal
// This function is a wrapper to primarycensored_lpmf
// Here the strings
// * family
// * dpars_A
// * dpars_1
// * dpars_2
// are/have been replaced using regex

real marginal_family_lpmf(data int d, dpars_A, data real pwindow) {
int dist_id = input_dist_id;
array[2] real params;
params[1] = mu;
params[2] = sigma;
params[1] = dpars_1;
params[2] = dpars_2;
int d_upper = d + 1;
int primary_id = 1; // Uniform
int primary_id = 1; // Fixed as uniform
array[0] real primary_params;
return primarycensored_lpmf(d | dist_id, params, pwindow, d_upper, positive_infinity(), primary_id, primary_params);
}

0 comments on commit 0d71938

Please sign in to comment.