Skip to content

Commit

Permalink
Issue #267: Refactor to allow custom event priors and marginalise the…
Browse files Browse the repository at this point in the history
… latent likelihood (#474)

* first pass at refactoring latent model to use window formulas

* add docs to stan function

* check getting started -drive by fix plotting

* update approach to handling formulas

* get reparameterisation from brms itself vs enforcing manual declaration

* work on regexing:

* test manually setting new priors

* fix .replace_prior

* reset for pause

* add back in lower bounds

* revert pass in via formula

* add custom priors pass in

* write priors down more neatly

* add manual prior mode and optout

* clean up easy test failures

* use marginalised log likelihood

* debug marginalised likelihood

* workaround for likelihood vectorisation

* further increase prior complexity options

* update prior ordering

* catch printing issue for .replace_prior

* add news item

* add PR links

* speed up test

* code read through

* clean up precommit

* turn off priorsense to check the theory that it is numerically unstable for extreme log lik values

* review comments
  • Loading branch information
seabbs authored Nov 28, 2024
1 parent 6d5b3f1 commit f9e2fc8
Show file tree
Hide file tree
Showing 43 changed files with 727 additions and 425 deletions.
1 change: 0 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ repos:
- id: check-added-large-files
args: ['--maxkb=200']
- id: end-of-file-fixer
exclude: ^tests/testthat/_snaps
- repo: local
hooks:
- id: forbid-to-commit
Expand Down
12 changes: 9 additions & 3 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@ S3method(assert_epidist,epidist_linelist_data)
S3method(assert_epidist,epidist_naive_model)
S3method(epidist_family_model,default)
S3method(epidist_family_model,epidist_latent_model)
S3method(epidist_family_param,default)
S3method(epidist_family_prior,default)
S3method(epidist_family_prior,lognormal)
S3method(epidist_family_reparam,default)
S3method(epidist_family_reparam,gamma)
S3method(epidist_formula_model,default)
S3method(epidist_formula_model,epidist_latent_model)
S3method(epidist_model_prior,default)
S3method(epidist_model_prior,epidist_latent_model)
S3method(epidist_stancode,default)
S3method(epidist_stancode,epidist_latent_model)
export(Gamma)
Expand All @@ -33,8 +33,8 @@ export(epidist)
export(epidist_diagnostics)
export(epidist_family)
export(epidist_family_model)
export(epidist_family_param)
export(epidist_family_prior)
export(epidist_family_reparam)
export(epidist_formula)
export(epidist_formula_model)
export(epidist_gen_posterior_epred)
Expand All @@ -57,16 +57,20 @@ export(simulate_secondary)
export(simulate_uniform_cases)
export(weibull)
import(ggplot2)
importFrom(brms,as.brmsprior)
importFrom(brms,bf)
importFrom(brms,lognormal)
importFrom(brms,make_stancode)
importFrom(brms,prior)
importFrom(brms,set_prior)
importFrom(brms,stanvar)
importFrom(brms,weibull)
importFrom(checkmate,assert_class)
importFrom(checkmate,assert_data_frame)
importFrom(checkmate,assert_date)
importFrom(checkmate,assert_factor)
importFrom(checkmate,assert_integer)
importFrom(checkmate,assert_integerish)
importFrom(checkmate,assert_names)
importFrom(checkmate,assert_numeric)
importFrom(checkmate,assert_true)
Expand All @@ -75,12 +79,14 @@ importFrom(cli,cli_alert_info)
importFrom(cli,cli_inform)
importFrom(cli,cli_warn)
importFrom(dplyr,bind_cols)
importFrom(dplyr,bind_rows)
importFrom(dplyr,filter)
importFrom(dplyr,full_join)
importFrom(dplyr,mutate)
importFrom(dplyr,select)
importFrom(lubridate,days)
importFrom(lubridate,is.timepoint)
importFrom(purrr,map_dbl)
importFrom(stats,Gamma)
importFrom(stats,as.formula)
importFrom(stats,setNames)
Expand Down
6 changes: 6 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,12 @@ Development version of `epidist`.
## Package

- Remove the default method for `epidist()`. See #473.
- Added `enforce_presence` argument to `epidist_prior()` to allow for priors to be
specified if they do not match existing parameters. See #474.
- Added a `merge` argument to `epidist_prior()` to allow for not merging user and package priors. See #474.
- Added user settable primary event priors to the latent model. See #474.
- Added a marginalised likelihood to the latent model. See #474.
- Generalised the Stan reparameterisation feature to work across all distributions without manual specification by generating Stan code with `brms` and then extracting the reparameterisation. See #474.

## Documentation

Expand Down
12 changes: 10 additions & 2 deletions R/epidist.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,13 @@
#' reexported as part of `epidist`.
#' @param prior One or more `brmsprior` objects created by [brms::set_prior()]
#' or related functions. These priors are passed to [epidist_prior()] in the
#' `prior` argument.
#' `prior` argument. Some models have default priors that are automatically
#' added (see [epidist_model_prior()]). These can be merged with user-provided
#' priors using the `merge_priors` argument.
#' @param merge_priors If `TRUE` then merge user priors with default priors, if
#' `FALSE` only use user priors. Defaults to `TRUE`. This may be useful if
#' the built in approaches for merging priors are not flexible enough for a
#' particular use case.
#' @param fn The internal function to be called. By default this is
#' [brms::brm()] which performs inference for the specified model. Other options
#' are [brms::make_stancode()] which returns the Stan code for the specified
Expand All @@ -25,14 +31,16 @@
#' @export
epidist <- function(data, formula = mu ~ 1,
family = lognormal(), prior = NULL,
merge_priors = TRUE,
fn = brms::brm, ...) {
assert_epidist(data)
epidist_family <- epidist_family(data, family)
epidist_formula <- epidist_formula(
data = data, family = epidist_family, formula = formula
)
epidist_prior <- epidist_prior(
data = data, family = epidist_family, formula = epidist_formula, prior
data = data, family = epidist_family, formula = epidist_formula, prior,
merge = merge_priors
)
epidist_stancode <- epidist_stancode(
data = data, family = epidist_family, formula = epidist_formula
Expand Down
76 changes: 60 additions & 16 deletions R/family.R
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ epidist_family <- function(data, family = lognormal(), ...) {
family <- .add_dpar_info(family)
custom_family <- epidist_family_model(data, family, ...)
class(custom_family) <- c(family$family, class(custom_family))
custom_family <- epidist_family_reparam(custom_family)
custom_family <- epidist_family_param(custom_family)
return(custom_family)
}

Expand Down Expand Up @@ -43,29 +43,73 @@ epidist_family_model.default <- function(data, family, ...) {
#' Reparameterise an `epidist` family to align `brms` and Stan
#'
#' @inheritParams epidist_family
#' @rdname epidist_family_reparam
#' @rdname epidist_family_param
#' @family family
#' @export
epidist_family_reparam <- function(family, ...) {
UseMethod("epidist_family_reparam")
epidist_family_param <- function(family, ...) {
  # S3 generic: dispatch to the method registered for the class of `family`
  UseMethod("epidist_family_param")
}

#' Default method for families which do not require a reparameterisation
#'
#' @inheritParams epidist_family
#' @family family
#' @export
epidist_family_reparam.default <- function(family, ...) {
  # No reparameterisation needed: reuse the family's distributional
  # parameters unchanged as its `reparam` element.
  family$reparam <- family$dpars
  family
}

#' Reparameterisation for the gamma family
#' This function extracts the Stan parameterisation for a given brms family by
#' creating a dummy model and parsing its Stan code. It looks for the log
#' probability density function (lpdf) call in the Stan code and extracts the
#' parameter order used by Stan. This is needed because brms and Stan may use
#' different parameter orderings for the same distribution.
#'
#' @param family A brms family object containing at minimum a `family` element
#' specifying the distribution family name
#' @param ... Additional arguments passed to methods (not used)
#'
#' @details
#' The function works by:
#' 1. Creating a minimal dummy model using the specified family
#' 2. Extracting the Stan code for this model
#' 3. Finding the lpdf function call for the family
#' 4. Parsing out the parameter ordering used in Stan
#' 5. Adding this as the `param` element to the family object
#'
#' @return The input family object with an additional `param` element containing
#' the Stan parameter ordering as a string
#'
#' @inheritParams epidist_family
#' @family family
#' @importFrom brms make_stancode
#' @importFrom cli cli_abort
#' @export
epidist_family_reparam.gamma <- function(family, ...) {
family$reparam <- c("shape", "shape ./ mu") # nolint
epidist_family_param.default <- function(family, ...) {
  # The first class of the family is used both as the brms family identifier
  # and, lowercased, as the prefix of the Stan lpdf/lpmf function name.
  family_class <- class(family)[1]
  family_name <- tolower(family_class)

  # Generate Stan code for a minimal intercept-only model
  df <- data.frame(y = c(1, 2))
  dummy_mdl <- make_stancode(y ~ 1, data = df, family = family_class)
  stan_code <- paste(as.character(dummy_mdl), collapse = "\n")

  # Locate *every* likelihood statement of the form
  # `target += <family>_lpdf(Y | <args>)`. gregexpr() is required here:
  # regexpr() only reports the first match, which would make the check for
  # multiple parameterisations below unreachable.
  lpdf_prefix <- paste0(
    "target \\+= ", family_name, "_(lpdf|lpmf)\\(Y \\| " # nolint
  )
  prefix_match <- gregexpr(lpdf_prefix, stan_code)[[1]]
  if (prefix_match[1] == -1) {
    cli_abort(
      "Unable to extract Stan parameterisation for {family_name}."
    )
  }

  # The arguments start directly after each matched prefix and run up to the
  # parenthesis that closes the lpdf/lpmf call.
  arg_starts <- as.integer(prefix_match) + attr(prefix_match, "match.length")
  params <- vapply(
    arg_starts,
    function(start) .extract_balanced_args(stan_code, start),
    character(1)
  )
  params <- params[!is.na(params)]
  if (length(params) == 0) {
    cli_abort(
      "Unable to extract Stan parameterisation for {family_name}."
    )
  }

  # Only parameterisations expressed in terms of `mu` are of interest. The
  # check is made on the extracted arguments rather than the whole statement
  # so that a family name containing "mu" cannot satisfy it by accident.
  mu_params <- params[grepl("mu", params, fixed = TRUE)]
  if (length(mu_params) > 1) {
    cli_abort("Multiple Stan parameterisations found with 'mu' parameter.")
  } else if (length(mu_params) == 0) {
    cli_abort("No Stan parameterisation found with 'mu' parameter.")
  }

  family$param <- mu_params[1]
  return(family)
}

#' Extract the arguments of a call up to its balanced closing parenthesis
#'
#' @param code A single string of Stan code.
#' @param start Position in `code` of the first character after the opening
#'   part of the call (i.e. just after `(Y | `).
#'
#' @return The text between `start` and the parenthesis closing the call, with
#'   any nested calls (e.g. `tgamma(1 + 1 / shape)`) kept intact, or
#'   `NA_character_` if the call is empty or never closed.
#' @keywords internal
.extract_balanced_args <- function(code, start) {
  chars <- strsplit(substr(code, start, nchar(code)), "", fixed = TRUE)[[1]]
  # Running nesting depth relative to the inside of the call: the call's own
  # closing parenthesis is the first position at which it drops below zero.
  depth <- cumsum((chars == "(") - (chars == ")"))
  closing <- which(depth < 0)
  if (length(closing) == 0 || closing[1] == 1) {
    return(NA_character_)
  }
  paste(chars[seq_len(closing[1] - 1)], collapse = "")
}
74 changes: 72 additions & 2 deletions R/gen.R
Original file line number Diff line number Diff line change
@@ -1,5 +1,75 @@
#' Create a function to calculate the marginalised log likelihood for double
#' censored and truncated delay distributions
#'
#' This function creates a log likelihood function that calculates the marginal
#' likelihood for a single observation by integrating over the latent primary
#' and secondary event windows. The integration is performed numerically using
#' [primarycensored::dpcens()] which handles the double censoring and truncation
#' of the delay distribution.
#'
#' The marginal likelihood accounts for uncertainty in both the primary and
#' secondary event windows by integrating over their possible values, weighted
#' by their respective uniform distributions.
#'
#' @seealso [brms::log_lik()] for details on the brms log likelihood interface.
#'
#' @inheritParams epidist_family
#'
#' @return A function that calculates the marginal log likelihood for a single
#' observation. The prep object must have the following variables:
#' * `vreal1`: relative observation time
#' * `vreal2`: primary event window
#' * `vreal3`: secondary event window
#'
#' @family gen
#' @autoglobal
#' @importFrom purrr map_dbl
epidist_gen_log_lik <- function(family) {
  # Get internal brms log_lik function
  log_lik_brms <- .get_brms_fn("log_lik", family)

  # Pointwise log likelihood for observation `i` given a brms prep object.
  # Returns one value per posterior draw (`prep$ndraws`).
  .log_lik <- function(i, prep) {
    y <- prep$data$Y[i]
    relative_obs_time <- prep$data$vreal1[i]
    pwindow <- prep$data$vreal2[i]
    swindow <- prep$data$vreal3[i]
    n_obs <- length(prep$data$Y)

    # Make the prep object censored so that the brms log_lik function can be
    # used as the delay CDF (`pdist`) below. brms codes left censoring as -1
    # (right censoring is 1); left censoring is what corresponds to a CDF.
    # The flag is replicated to the length of `Y`, exactly as `Y` itself is
    # in `pdist_draw()`, so that element `i` is defined for every observation
    # index: with a scalar, `cens[i]` would be `NA` for any `i > 1`.
    # NOTE(review): presumably brms reads `cens[i]` per observation in the
    # same way it reads `Y[i]` - confirm against the brms internals.
    prep$data$cens <- rep(-1, n_obs)

    # Calculate density for each draw using primarycensored::dpcens()
    lpdf <- purrr::map_dbl(seq_len(prep$ndraws), function(draw) {
      # Define pdist function that filters to current draw. `i` and `prep`
      # are supplied through the extra arguments given to dpcens() below.
      pdist_draw <- function(q, i, prep, ...) {
        purrr::map_dbl(q, function(x) {
          # Overwrite the outcome with the quantile being evaluated so that
          # observation `i` is evaluated at `x`
          prep$data$Y <- rep(x, length(prep$data$Y))
          # log_lik_brms() returns a value per draw; keep the current one
          # NOTE(review): if `prep$data$weights` is set, the brms function
          # may already apply the weight here, before it is applied again
          # below - confirm.
          ll <- exp(log_lik_brms(i, prep)[draw])
          return(ll)
        })
      }

      primarycensored::dpcens(
        x = y,
        pdist = pdist_draw,
        i = i,
        prep = prep,
        pwindow = pwindow,
        swindow = swindow,
        D = relative_obs_time,
        dprimary = stats::dunif,
        log = TRUE
      )
    })
    lpdf <- brms:::log_lik_weight(lpdf, i = i, prep = prep) # nolint
    return(lpdf)
  }

  return(.log_lik)
}

#' Create a function to draw from the posterior predictive distribution for a
#' latent model
#' double censored and truncated delay distribution
#'
#' This function creates a function that draws from the posterior predictive
#' distribution for a latent model using [primarycensored::rpcens()] to handle
Expand Down Expand Up @@ -49,7 +119,7 @@ epidist_gen_posterior_predict <- function(family) {
}

#' Create a function to draw from the expected value of the posterior predictive
#' distribution for a latent model
#' distribution for a model
#'
#' This function creates a function that calculates the expected value of the
#' posterior predictive distribution for a latent model. The returned function
Expand Down
1 change: 1 addition & 0 deletions R/globals.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ utils::globalVariables(c(
"samples", # <epidist_diagnostics>
"woverlap", # <epidist_stancode.epidist_latent_model>
"rlnorm", # <simulate_secondary>
"fix", # <.replace_prior>
"prior_new", # <.replace_prior>
"source_new", # <.replace_prior>
NULL
Expand Down
Loading

0 comments on commit f9e2fc8

Please sign in to comment.