Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Issue #442: Generalise to all brms distributions #459

Merged
merged 16 commits into from
Nov 21, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -25,3 +25,7 @@ data/models/*reparam
docs
/doc/
/Meta/
vignettes/**_cache/
*.pdf
.vscode/
vignettes/figures/
3 changes: 2 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ Imports:
rstan (>= 2.26.0),
dplyr,
tibble,
lubridate
lubridate,
primarycensored
Suggests:
bookdown,
testthat (>= 3.0.0),
Expand Down
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ export(epidist_family_prior)
export(epidist_family_reparam)
export(epidist_formula)
export(epidist_formula_model)
export(epidist_gen_posterior_epred)
export(epidist_gen_posterior_predict)
export(epidist_model_prior)
export(epidist_prior)
export(epidist_stancode)
Expand Down
73 changes: 73 additions & 0 deletions R/gen.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
#' Create a function to draw from the posterior predictive distribution for a
#' latent model
#'
#' This function creates a function that draws from the posterior predictive
#' distribution for a latent model using [primarycensored::rpcens()] to handle
#' censoring and truncation. The returned function takes a `prep` argument from
#' `brms` and returns posterior predictions. This is used internally by
#' [brms::posterior_predict()] to generate predictions for latent models.
#'
#' @inheritParams epidist_family
#'
#' @return A function that takes a `prep` argument from brms and returns a
#' matrix of posterior predictions, with one row per posterior draw and one
#' column per observation. The `prep` object must have the following variables:
#' * `vreal1`: relative observation time
#' * `vreal2`: primary event window
#' * `vreal3`: secondary event window
#'
#' @seealso [brms::posterior_predict()] for details on how this is used within
#' `brms`, [primarycensored::rpcens()] for details on the censoring approach
#' @autoglobal
#' @family gen
#' @export
epidist_gen_posterior_predict <- function(family) {
seabbs marked this conversation as resolved.
Show resolved Hide resolved
dist_fn <- .get_brms_fn("posterior_predict", family)

rdist <- function(n, i, prep, ...) {
prep$ndraws <- n
do.call(dist_fn, list(i = i, prep = prep))
}

.predict <- function(i, prep, ...) {
relative_obs_time <- prep$data$vreal1[i]
seabbs marked this conversation as resolved.
Show resolved Hide resolved
pwindow <- prep$data$vreal2[i]
swindow <- prep$data$vreal3[i]

as.matrix(primarycensored::rpcens(
n = prep$ndraws,
rdist = rdist,
rprimary = stats::runif,
pwindow = prep$data$vreal2[i],
swindow = prep$data$vreal3[i],
D = prep$data$vreal1[i],
i = i,
prep = prep
))
}
return(.predict)
}

#' Create a function to draw from the expected value of the posterior predictive
#' distribution for a latent model
#'
#' This function creates a function that calculates the expected value of the
#' posterior predictive distribution for a latent model. The returned function
#' takes a `prep` argument (from brms) and returns posterior expected values.
#' This is used internally by [brms::posterior_epred()] to calculate expected
#' values for latent models.
#'
#' @inheritParams epidist_family
#'
#' @return A function that takes a prep argument from brms and returns a matrix
#' of posterior expected values, with one row per posterior draw and one column
#' per observation.
#'
#' @seealso [brms::posterior_epred()] for details on how this is used within
#' `brms`.
#' @autoglobal
#' @family gen
#' @export
epidist_gen_posterior_epred <- function(family) {
.get_brms_fn("posterior_epred", family)
}
7 changes: 0 additions & 7 deletions R/globals.R
Original file line number Diff line number Diff line change
@@ -1,16 +1,9 @@
# Generated by roxyglobals: do not edit by hand

utils::globalVariables(c(
".data", # <epidist_diagnostics>
"samples", # <epidist_diagnostics>
".data", # <as_epidist_latent_model.epidist_linelist_data>
"woverlap", # <epidist_stancode.epidist_latent_model>
".data", # <as_epidist_naive_model.epidist_linelist_data>
".data", # <add_mean_sd.lognormal_samples>
".data", # <add_mean_sd.gamma_samples>
"rlnorm", # <simulate_secondary>
".data", # <simulate_secondary>
".data", # <.replace_prior>
"prior_new", # <.replace_prior>
"source_new", # <.replace_prior>
NULL
Expand Down
84 changes: 0 additions & 84 deletions R/latent_gamma.R

This file was deleted.

87 changes: 0 additions & 87 deletions R/latent_lognormal.R

This file was deleted.

66 changes: 65 additions & 1 deletion R/latent_model.R
Original file line number Diff line number Diff line change
Expand Up @@ -89,12 +89,76 @@ epidist_family_model.epidist_latent_model <- function(
ub = c(NA, as.numeric(lapply(family$other_bounds, "[[", "ub"))),
type = family$type,
vars = c("pwindow", "swindow", "vreal1"),
loop = FALSE
loop = FALSE,
log_lik = epidist_gen_log_lik_latent(family),
posterior_predict = epidist_gen_posterior_predict(family),
posterior_epred = epidist_gen_posterior_epred(family)
)
custom_family$reparm <- family$reparm
return(custom_family)
}

#' Create a function to calculate the pointwise log likelihood of the latent
#' model
#'
#' This function creates a log likelihood function that accounts for the latent
#' variables in the model, including primary and secondary event windows and
#' their overlap. The returned function calculates the log likelihood for a
#' single observation by augmenting the data with the latent variables and
#' using the underlying brms log likelihood function.
#'
#' @seealso [brms::log_lik()] for details on the brms log likelihood interface.
#'
#' @inheritParams epidist_family
#'
#' @return A function that calculates the log likelihood for a single
#' observation. The prep object must have the following variables:
#' * `vreal1`: relative observation time
#' * `vreal2`: primary event window
#' * `vreal3`: secondary event window
#'
#' @family latent_model
#' @autoglobal
epidist_gen_log_lik_latent <- function(family) {
# Get internal brms log_lik function
log_lik_brms <- .get_brms_fn("log_lik", family)

.log_lik <- function(i, prep) {
seabbs marked this conversation as resolved.
Show resolved Hide resolved
y <- prep$data$Y[i]
relative_obs_time <- prep$data$vreal1[i]
pwindow <- prep$data$vreal2[i]
swindow <- prep$data$vreal3[i]

# Generates values of the swindow_raw and pwindow_raw, but really these
# should be extracted from prep or the fitted raws somehow. See:
# https://github.com/epinowcast/epidist/issues/267
swindow_raw <- stats::runif(prep$ndraws)
pwindow_raw <- stats::runif(prep$ndraws)

swindow <- swindow_raw * swindow

# For no overlap calculate as usual, for overlap ensure pwindow < swindow
if (i %in% prep$data$noverlap) {
pwindow <- pwindow_raw * pwindow
} else {
pwindow <- pwindow_raw * swindow
}

d <- y - pwindow + swindow
obs_time <- relative_obs_time - pwindow
# Create brms truncation upper bound
prep$data$ub <- rep(obs_time, length(prep$data$Y))
# Update augmented data
prep$data$Y <- rep(d, length(prep$data$Y))

# Call internal brms log_lik function with augmented data
lpdf <- log_lik_brms(i, prep)
return(lpdf)
}

return(.log_lik)
}

#' Define the model-specific component of an `epidist` custom formula
#'
#' @inheritParams epidist_formula_model
Expand Down
19 changes: 19 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -157,3 +157,22 @@

return(df)
}

#' Get a brms function by prefix and family
#'
#' Helper function to get internal brms functions by constructing their name
#' from a prefix and family. Used to get functions like `log_lik_*`,
#' `posterior_predict_*` etc.
#'
#' @param prefix Character string prefix of the brms function to get (e.g.
#' "log_lik")
#'
#' @inheritParams epidist_family
#' @return The requested brms function
#' @keywords internal
.get_brms_fn <- function(prefix, family) {
get(
paste0(prefix, "_", tolower(family$family)),
asNamespace("brms")
)
}
Loading
Loading