From 4decb5c1e7b09a0f9991e9fce51db5821124464e Mon Sep 17 00:00:00 2001 From: athowes Date: Mon, 10 Jun 2024 15:49:50 +0100 Subject: [PATCH] Update vignette outline and use S3 versions --- vignettes/approx-inference.Rmd | 147 +++++++-------------------------- 1 file changed, 29 insertions(+), 118 deletions(-) diff --git a/vignettes/approx-inference.Rmd b/vignettes/approx-inference.Rmd index 6793965af..e005d162c 100644 --- a/vignettes/approx-inference.Rmd +++ b/vignettes/approx-inference.Rmd @@ -30,6 +30,30 @@ knitr::opts_chunk$set( ) ``` +# Background + +* What is the default inference method used in `epidist` (HMC) + * How does it work + * Why is it the default + * What are its strengths + * What are its drawbacks +* What are the atlernatives + * Why might you consider them + +## Laplace + +* Briefly, how does it work + +## Variational inference + +* Briefly, how does it work + +## Pathfinder + +* Briefly, how does it work + +# Demonstration + ```{r load-requirements} library(epidist) library(data.table) @@ -55,125 +79,12 @@ obs_cens_trunc <- simulate_gillespie() |> obs_cens_trunc_samp <- obs_cens_trunc[sample(seq_len(.N), sample_size, replace = FALSE)] -data <- obs_cens_trunc_samp - -formula <- brms::bf(delay_central | vreal(obs_t, pwindow_upr, swindow_upr) ~ 1, sigma ~ 1) - -fn <- brms::brm - -family <- brms::custom_family( - "latent_lognormal", - dpars = c("mu", "sigma"), - links = c("identity", "log"), - lb = c(NA, 0), - ub = c(NA, NA), - type = "real", - vars = c("pwindow", "swindow", "vreal1"), - loop = FALSE -) - -scode_functions <- " - \n real latent_lognormal_lpdf(vector y, vector mu, vector sigma, - \n vector pwindow, vector swindow, - \n array[] real obs_t) { - \n int n = num_elements(y); - \n vector[n] d = y - pwindow + swindow; - \n vector[n] obs_time = to_vector(obs_t) - pwindow; - \n return lognormal_lpdf(d | mu, sigma) - - \n lognormal_lcdf(obs_time | mu, sigma); - \n } - \n -" - -scode_parameters <- " - \n vector[N] swindow_raw; - \n vector[N] pwindow_raw; - \n -" - -scode_tparameters <- " - \n vector[N] pwindow;\n vector[N] swindow; - \n swindow = to_vector(vreal3) .* swindow_raw; - \n pwindow[noverlap] = to_vector(vreal2[noverlap]) .* pwindow_raw[noverlap]; - \n if (wN) { - \n pwindow[woverlap] = swindow[woverlap] .* pwindow_raw[woverlap]; - \n } - \n -" - -scode_priors <- " - \n swindow_raw ~ uniform(0, 1); - \n pwindow_raw ~ uniform(0, 1); - \n -" - -data <- data.table::as.data.table(data) -data[, id := 1:.N] -data[, obs_t := obs_at - ptime_lwr] - -data[, pwindow_upr := ifelse( - stime_lwr < ptime_upr, - stime_upr - ptime_lwr, - ptime_upr - ptime_lwr -)] - -data[, woverlap := as.numeric(stime_lwr < ptime_upr)] -data[, swindow_upr := stime_upr - stime_lwr] -data[, delay_central := stime_lwr - ptime_lwr] -data[, row_id := 1:.N] - -if (nrow(data) > 1) { - data <- data[, id := as.factor(id)] -} - -stanvars_functions <- brms::stanvar( - block = "functions", scode = scode_functions -) - -stanvars_data <- brms::stanvar( - block = "data", scode = "int wN;", - x = nrow(data[woverlap > 0]), - name = "wN" -) + - -brms::stanvar( - block = "data", scode = "array[N - wN] int noverlap;", - x = data[woverlap == 0][, row_id], - name = "noverlap" -) + -brms::stanvar( - block = "data", scode = "array[wN] int woverlap;", - x = data[woverlap > 0][, row_id], - name = "woverlap" -) - -stanvars_parameters <- brms::stanvar( - block = "parameters", scode = scode_parameters -) - -stanvars_tparameters <- brms::stanvar( - block = "tparameters", scode = scode_tparameters -) +data <- epidist_prepare(obs_cens_trunc_samp, model = "ltcad") -stanvars_priors <- brms::stanvar(block = "model", scode = scode_priors) - -stanvars_all <- stanvars_functions + stanvars_data + stanvars_parameters + - stanvars_tparameters + stanvars_priors - -fit_laplace <- fn( - formula = formula, family = family, stanvars = stanvars_all, - backend = "cmdstanr", data = data, algorithm = "laplace" -) - -fit_pathfinder <- fn( - formula = formula, family = family, stanvars = stanvars_all, - backend = "cmdstanr", data = data, algorithm = "pathfinder" -) - -fit_hmc <- fn( - formula = formula, family = family, stanvars = stanvars_all, - backend = "cmdstanr", data = data, algorithm = "sampling" -) +fit_hmc <- epidist(data = data, algorithm = "sampling") +fit_laplace <- epidist(data = data, algorithm = "laplace") +fit_variational <- epidist(data = data, algorithm = "meanfield") +fit_pathfinder <- epidist(data = data, algorithm = "pathfinder") ``` ## Bibliography {-}