Skip to content

Commit

Permalink
Issue #338: Add double dispatch for epidist_family (#365)
Browse files Browse the repository at this point in the history
* Start working on double dispatch refactor

* Set up the double dispatch structure

* Documentation fixes

* Standardise documentation a bit

* Remove scratch

* Fix to lintr

* Start moving things into epidist_family_family

* Add family to pkgdown

* Rework of double dispatch. The family_family part is best as only a reparameterisation, the rest is the same for all families

* Move the additional dpar info into helper, correct bug with "other" name, document and lint

* Missing manual page

* Add dispatch on family into reparam

* Start thinking about tests

* Add or move tests for the family functionality

* Bug fix for dispatch on family

* Move .add_dpar_info() into epidist_family

* Improve documentation for .add_dpar_info() and move it to utils.R

* Add test for .add_dpar_info

* Fix to reparam.gamam documentation

* Move data validation into higher level and lint
  • Loading branch information
athowes authored Oct 7, 2024
1 parent cade3d0 commit 94812c2
Show file tree
Hide file tree
Showing 32 changed files with 367 additions and 164 deletions.
8 changes: 6 additions & 2 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,12 @@ S3method(add_mean_sd,gamma_samples)
S3method(add_mean_sd,lognormal_samples)
S3method(as_latent_individual,data.frame)
S3method(epidist,default)
S3method(epidist_family,default)
S3method(epidist_family,epidist_latent_individual)
S3method(epidist_family_model,default)
S3method(epidist_family_model,epidist_latent_individual)
S3method(epidist_family_prior,default)
S3method(epidist_family_prior,lognormal)
S3method(epidist_family_reparam,default)
S3method(epidist_family_reparam,gamma)
S3method(epidist_formula,default)
S3method(epidist_formula,epidist_latent_individual)
S3method(epidist_model_prior,default)
Expand All @@ -21,7 +23,9 @@ export(as_latent_individual)
export(epidist)
export(epidist_diagnostics)
export(epidist_family)
export(epidist_family_model)
export(epidist_family_prior)
export(epidist_family_reparam)
export(epidist_formula)
export(epidist_model_prior)
export(epidist_prior)
Expand Down
13 changes: 0 additions & 13 deletions R/defaults.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,6 @@ epidist_formula.default <- function(data, ...) {
)
}

#' Default method for defining a model specific family
#'
#' @inheritParams epidist_family
#' @param ... Additional arguments passed to method.
#' @family defaults
#' @export
epidist_family.default <- function(data, ...) {
cli_abort(
"No epidist_family method implemented for the class ", class(data), "\n",
"See methods(epidist_family) for available methods"
)
}

#' Default method for defining model specific Stan code
#'
#' @inheritParams epidist_stancode
Expand Down
82 changes: 82 additions & 0 deletions R/family.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
#' Define `epidist` family
#'
#' This function is used within [epidist()] to create a model specific custom
#' `brms` family object. This custom family is passed to `brms`. It is unlikely
#' that as a user you will need this function, but we export it nonetheless to
#' be transparent about what happens inside of a call to [epidist()].
#'
#' @param data A `data.frame` containing line list data
#' @param family Output of a call to `brms::brmsfamily()`
#' @param ... ...
#'
#' @family family
#' @export
epidist_family <- function(data, family = "lognormal", ...) {
epidist_validate(data)
family <- brms:::validate_family(family)
class(family) <- c(family$family, class(family))
family <- .add_dpar_info(family)
custom_family <- epidist_family_model(data, family, ...)
class(custom_family) <- c(family$family, class(custom_family))
custom_family <- epidist_family_reparam(custom_family)
return(custom_family)
}

#' The model-specific parts of an `epidist_family()` call
#'
#' @inheritParams epidist_family
#' @param family Output of a call to `brms::brmsfamily()` with additional
#' information as provided by `.add_dpar_info()`
#' @param ... Additional arguments passed to method.
#' @rdname epidist_family_model
#' @family family
#' @export
epidist_family_model <- function(data, family, ...) {
UseMethod("epidist_family_model")
}

#' Default method for defining a model specific family
#'
#' @inheritParams epidist_family_model
#' @param ... Additional arguments passed to method.
#' @family family
#' @export
epidist_family_model.default <- function(data, ...) {
cli_abort(
"No epidist_family_model method implemented for the class ", class(data),
"\n", "See methods(epidist_family_model) for available methods"
)
}

#' Reparameterise an `epidist` family to align `brms` and Stan
#'
#' @inheritParams epidist_family
#' @param ... Additional arguments passed to method.
#' @rdname epidist_family_reparam
#' @family family
#' @export
epidist_family_reparam <- function(family, ...) {
UseMethod("epidist_family_reparam")
}

#' Default method for families which do not require a reparameterisation
#'
#' @inheritParams epidist_family_reparam
#' @param ... Additional arguments passed to method.
#' @family family
#' @export
epidist_family_reparam.default <- function(family, ...) {
family$reparam <- family$dpars
return(family)
}

#' Reparameterisation for the gamma family
#'
#' @inheritParams epidist_family_reparam
#' @param ... Additional arguments passed to method.
#' @family family
#' @export
epidist_family_reparam.gamma <- function(family, ...) {
family$reparam <- c("shape", "shape ./ mu")
return(family)
}
17 changes: 1 addition & 16 deletions R/generics.R
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
#' particular `epidist` model. This may include checking the class of `data`,
#' and that it contains suitable columns.
#'
#' @param data A `data.frame` to be used for modelling.
#' @param data A `data.frame` containing line list data.
#' @family generics
#' @export
epidist_validate <- function(data) {
Expand All @@ -26,21 +26,6 @@ epidist_formula <- function(data, ...) {
UseMethod("epidist_formula")
}

#' Define model specific family
#'
#' This function is used within [epidist()] to create a model specific custom
#' `brms` family object. This object is passed to `brms`. It is unlikely that
#' as a user you will need this function, but we export it nonetheless to be
#' transparent about what exactly is happening inside of a call to [epidist()].
#'
#' @inheritParams epidist_validate
#' @param ... Additional arguments passed to method.
#' @family generics
#' @export
epidist_family <- function(data, ...) {
UseMethod("epidist_family")
}

#' Define model specific Stan code
#'
#' This function is used within [epidist()] to create any custom Stan code which
Expand Down
40 changes: 14 additions & 26 deletions R/latent_individual.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#' Prepare latent individual model
#'
#' @param data Input data to be used for modelling.
#' @param data A `data.frame` containing line list data
#' @family latent_individual
#' @export
as_latent_individual <- function(data) {
Expand Down Expand Up @@ -104,46 +104,34 @@ is_latent_individual <- function(data) {
inherits(data, "epidist_latent_individual")
}

#' Check if data has the `epidist_latent_individual` class
#'
#' @param data A `data.frame` containing line list data
#' @param family Output of a call to `brms::brmsfamily()`
#' @param ... ...
#' Create the model-specific component of an `epidist` custom family
#'
#' @method epidist_family epidist_latent_individual
#' @inheritParams epidist_family_model
#' @param ... Additional arguments passed to method.
#' @method epidist_family_model epidist_latent_individual
#' @family latent_individual
#' @export
epidist_family.epidist_latent_individual <- function(data,
family = "lognormal",
...) {
epidist_validate(data)
# allows use of stats::family and strings
family <- brms:::validate_family(family)
non_mu_links <- family[[paste0("link_", setdiff(family$dpars, "mu"))]]
non_mu_bounds <- lapply(
family$dpars[-1], brms:::dpar_bounds, family = family$family
)
epidist_family_model.epidist_latent_individual <- function(
data, family, ...
) {
# Really the name and vars are the "model-specific" parts here
custom_family <- brms::custom_family(
paste0("latent_", family$family),
dpars = family$dpars,
links = c(family$link, non_mu_links),
lb = c(NA, as.numeric(lapply(non_mu_bounds, "[[", "lb"))),
ub = c(NA, as.numeric(lapply(non_mu_bounds, "[[", "ub"))),
links = c(family$link, family$other_links),
lb = c(NA, as.numeric(lapply(family$other_bounds, "[[", "lb"))),
ub = c(NA, as.numeric(lapply(family$other_bounds, "[[", "ub"))),
type = family$type,
vars = c("pwindow", "swindow", "vreal1"),
loop = FALSE
)
reparam <- family$dpars
if (family$family == "gamma") {
reparam <- c("shape", "shape ./ mu")
}
custom_family$reparam <- reparam
custom_family$reparm <- family$reparm
return(custom_family)
}

#' Define a formula for the latent_individual model
#'
#' @param data ...
#' @param data A `data.frame` containing line list data
#' @param family The output of [epidist_family()]
#' @param formula As produced by [brms::brmsformula()]
#' @param ... ...
Expand Down
18 changes: 18 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,21 @@

return(prior)
}

#' Additional distributional parameter information for `brms` families
#'
#' Includes additional information (link functions and parameter bound) about
#' the distributional parameters of a `brms` family which are not the
#' conditional mean `mu`.
#'
#' @inheritParams epidist_family
#' @keywords internal
.add_dpar_info <- function(family) {
other_links <- family[[paste0("link_", setdiff(family$dpars, "mu"))]]
other_bounds <- lapply(
family$dpars[-1], brms:::dpar_bounds, family = family$family
)
family$other_links <- other_links
family$other_bounds <- other_bounds
return(family)
}
4 changes: 4 additions & 0 deletions _pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,10 @@ reference:
desc: Default methods for S3 generics
contents:
- has_concept("defaults")
- title: Family
desc: Functions related to specifying custom `brms` families
contents:
- has_concept("family")
- title: Priors
desc: Functions for specifying prior distributions
contents:
Expand Down
4 changes: 2 additions & 2 deletions man/as_latent_individual.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 17 additions & 0 deletions man/dot-add_dpar_info.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 1 addition & 2 deletions man/epidist.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 1 addition & 2 deletions man/epidist.default.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

31 changes: 17 additions & 14 deletions man/epidist_family.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

24 changes: 0 additions & 24 deletions man/epidist_family.default.Rd

This file was deleted.

Loading

0 comments on commit 94812c2

Please sign in to comment.