Skip to content

Commit

Permalink
basic working version
Browse files Browse the repository at this point in the history
  • Loading branch information
seabbs committed Nov 29, 2024
1 parent cb32195 commit 67cd715
Show file tree
Hide file tree
Showing 5 changed files with 31 additions and 29 deletions.
2 changes: 1 addition & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ S3method(assert_epidist,epidist_marginal_model)
S3method(assert_epidist,epidist_naive_model)
S3method(epidist_family_model,default)
S3method(epidist_family_model,epidist_latent_model)
S3method(epidist_family_param,default)
S3method(epidist_family_model,epidist_marginal_model)
S3method(epidist_family_param,default)
S3method(epidist_family_prior,default)
S3method(epidist_family_prior,lognormal)
S3method(epidist_formula_model,default)
Expand Down
2 changes: 0 additions & 2 deletions R/globals.R
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
utils::globalVariables(c(
"samples", # <epidist_diagnostics>
"woverlap", # <epidist_stancode.epidist_latent_model>
"delay", # <as_epidist_marginal_model.epidist_linelist_data>
"pwindow", # <as_epidist_marginal_model.epidist_linelist_data>
"rlnorm", # <simulate_secondary>
"fix", # <.replace_prior>
"prior_new", # <.replace_prior>
Expand Down
40 changes: 21 additions & 19 deletions R/marginal_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -54,17 +54,19 @@ assert_epidist.epidist_marginal_model <- function(data, ...) {
))
assert_numeric(data$pwindow, lower = 0)
assert_numeric(data$swindow, lower = 0)
assert_numeric(data$delay_lwr)
assert_numeric(data$delay_upr)
assert_integerish(data$delay_lwr)
assert_integerish(data$delay_upr)
assert_numeric(data$relative_obs_time)
assert_true(
all(abs(data$delay_upr - (data$delay_lwr + data$swindow)) < 1e-10),
"delay_upr must equal delay_lwr + swindow"
)
assert_true(
all(data$relative_obs_time >= data$delay_upr),
"relative_obs_time must be greater than or equal to delay_upr"
)
if (!all(abs(data$delay_upr - (data$delay_lwr + data$swindow)) < 1e-10)) {
cli::cli_abort(
"delay_upr must equal delay_lwr + swindow"
)
}
if (!all(data$relative_obs_time >= data$delay_upr)) {
cli::cli_abort(
"relative_obs_time must be greater than or equal to delay_upr"
)
}
assert_numeric(data$n, lower = 1)
}

Expand Down Expand Up @@ -92,9 +94,9 @@ epidist_family_model.epidist_marginal_model <- function(
links = c(family$link, family$other_links),
lb = c(NA, as.numeric(lapply(family$other_bounds, "[[", "lb"))),
ub = c(NA, as.numeric(lapply(family$other_bounds, "[[", "ub"))),
type = family$type,
type = "int",
vars = c(
"vreal1[n]", "vreal2[n]", "vreal3[n]", "vreal4[n]"
"vreal1[n]", "vreal2[n]", "vreal3[n]", "vreal4[n]", "primary_params"
),
loop = TRUE,
log_lik = epidist_gen_log_lik(family),
Expand Down Expand Up @@ -146,7 +148,6 @@ epidist_stancode.epidist_marginal_model <- function(
fixed = TRUE
)

# Can probably be extended to non-analytic solution families but for now
if (family_name == "lognormal") {
dist_id <- 1
} else if (family_name == "gamma") {
Expand All @@ -155,7 +156,7 @@ epidist_stancode.epidist_marginal_model <- function(
dist_id <- 3
} else {
cli_abort(c(
"!" = "No analytic solution available in primarycensored for this family"
"!" = "No solution available in primarycensored for this family"
))
}

Expand All @@ -167,7 +168,7 @@ epidist_stancode.epidist_marginal_model <- function(

stanvars_functions[[1]]$scode <- gsub(
"dpars_A",
toString(paste0(vector_real, " ", family$dpars)),
toString(paste0("real ", family$dpars)),
stanvars_functions[[1]]$scode,
fixed = TRUE
)
Expand All @@ -182,17 +183,18 @@ epidist_stancode.epidist_marginal_model <- function(
fixed = TRUE
)

stanvars_functions[[1]]$scode <- gsub(
"primary_params", "", stanvars_functions[[1]]$scode,
fixed = TRUE
stanvars_parameters <- brms::stanvar(
block = "parameters",
scode = "array[0] real primary_params;"
)

pcd_stanvars_functions <- brms::stanvar(
block = "functions",
scode = primarycensored::pcd_load_stan_functions()
)

stanvars_all <- stanvars_version + stanvars_functions + pcd_stanvars_functions
stanvars_all <- stanvars_version + stanvars_functions +
pcd_stanvars_functions + stanvars_parameters

return(stanvars_all)
}
14 changes: 8 additions & 6 deletions inst/stan/marginal_model/functions.stan
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/**
* Compute the log probability density function for a marginal model with censoring
* Compute the log probability mass function for a marginal model with censoring
*
* This function is designed to be read into R where:
* - 'family' is replaced with the target distribution (e.g., 'lognormal')
Expand All @@ -15,15 +15,17 @@
* @param relative_obs_t Observation time relative to primary window start
* @param pwindow_width Primary window width (actual time scale)
* @param swindow_width Secondary window width (actual time scale)
* @param primary_params Array of parameters for primary distribution
*
* @return Log probability density with censoring adjustment for marginal model
* @return Log probability mass with censoring adjustment for marginal model
*/
real marginal_family_lpdf(data real y, dpars_A, data real y_upper,
real marginal_family_lpmf(data int y, dpars_A, data real y_upper,
data real relative_obs_t, data real pwindow_width,
data real swindow_width) {
data real swindow_width,
array[] real primary_params) {

return primarycensored_lpmf(
y | dist_id, {dpars_B}, pwindow, y_upper, relative_obs_t,
primary_id, {primary_params}
y | dist_id, {dpars_B}, pwindow_width, y_upper, relative_obs_t,
primary_id, primary_params
);
}
2 changes: 1 addition & 1 deletion vignettes/epidist.Rmd
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ linelist_data <- as_epidist_linelist_data(
obs_time = obs_cens_trunc_samp$obs_time
)
data <- as_epidist_latent_model(linelist_data)
data <- as_epidist_marginal_model(linelist_data)
class(data)
```

Expand Down

0 comments on commit 67cd715

Please sign in to comment.