Commit 8023c79 (parent: 7284403), showing 11 changed files with 800 additions and 2 deletions.
@@ -0,0 +1,3 @@
.rds
.csv
./data/*
@@ -1,2 +1,42 @@
# flu-forecast-2022-23
Code for the paper: Forecasting influenza hospital admissions within English sub-regions using hierarchical generalised additive models
# README

Supporting code for the publication "Forecasting influenza hospital admissions within English sub-regions using hierarchical generalised additive models" in Nature Communications Medicine.

## Data

The data used for this study were live operational data from NHS England ([definition given here](https://www.england.nhs.uk/long-read/process-and-definitions-for-the-daily-situation-report-web-form/)). These influenza data are not in the public domain, so we have provided simulated data in this repository to demonstrate the models.

The data are generated assuming a national epidemic wave, represented by local units that follow the same basic trend plus perturbations.
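As a rough illustration of that idea (a standalone sketch, not the repository's simulation code in `~/simulate_hierarchical_epidemic.R`; the object names below are made up):

```r
# Standalone sketch: a shared national wave plus per-unit scaling and noise.
# Names (n_units, final_t, national_wave) are illustrative only.
set.seed(1)
n_units <- 9
final_t <- 150
t <- seq_len(final_t)
national_wave <- 500 * dnorm(t, mean = 75, sd = 20)  # smooth national trend

sim <- do.call(rbind, lapply(seq_len(n_units), function(u) {
  data.frame(
    sub_region     = paste0("sub_region_", u),
    t              = t,
    # each unit follows the national trend, rescaled and perturbed
    flu_admissions = rpois(final_t, lambda = national_wave * runif(1, 0.5, 1.5))
  )
}))
head(sim)
```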

## Structure

Running this code assumes the working directory is the repository root (i.e. where `show_models.R` sits).

The main file to run to test the models and visualise them is `~/show_models.R`.

This script will:

1. install the package dependencies from `~/src/depends.R` (you may want to install `here` first to set the working directory)
2. generate the simulated epidemic data from `~/simulate_hierarchical_epidemic.R`
3. source and run the ARIMA model from `~/run_arima.R`
4. source and run the GAM model from `~/run_gam.R`
5. write the different data out to `~/data/`
6. read the data back in, plot it and score the models

The configuration for data generation, overall modelling, and GAM parameters can be adjusted in `~/config.yaml`.
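The scripts read this file with `yaml::read_yaml()`, so the values can also be inspected interactively, for example:

```r
# Read the configuration and pull out a few values (sketch)
config <- yaml::read_yaml("./config.yaml")
config$overall_parameters$horizon       # forecast horizon in days
config$overall_parameters$n_lookbacks   # number of historic projections
config$gam_parameters$national_spline   # days per basis function, national spline
```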

Note that the models have not been pre-tuned.

- The bounds of the parameter ranges for the ARIMA can be set within `~/src/arima.R`, though the model will select sensible values itself.
- The days per basis function for the GAM can be updated in the `~/config.yaml` file (a rough sketch of what this controls follows this list).
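Loosely, fewer days per basis function means a wigglier spline. Assuming the basis dimension works out as roughly the fitting window divided by the days per basis function (an assumption here; `~/src/gam.R` has the actual implementation), the defaults give:

```r
# Back-of-the-envelope basis dimensions under the default config
# (assumes k ~ fitting_window / days_per_basis_function -- check ~/src/gam.R)
fitting_window <- 63
ceiling(fitting_window / 5)  # national spline: ~13 basis functions
ceiling(fitting_window / 8)  # sub-regional splines: ~8 basis functions
```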

The GAM can be run either in parallel or sequentially; the option to change this is in `~/config.yaml`. Each fit takes on the order of 1-5 minutes, so adjust your run-time expectations accordingly if running sequentially. In the default setting, parallel execution requires roughly 12 cores, one per lookback (`n_lookbacks`).
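Before enabling parallel execution it may be worth checking the cores available on your machine, for example:

```r
# how many cores does this machine have?
parallel::detectCores()
# if this is well below ~12, consider setting is_parallel: FALSE in ~/config.yaml
# and budgeting roughly 1-5 minutes per fit for a sequential run
```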

## Running

To test the code and its different components you can either:

1. run `~/simulate_hierarchical_epidemic.R`, `~/run_arima.R` and `~/run_gam.R` separately to inspect each step, or
2. bring it all together with `~/show_models.R`, which also contains the plotting and scoring code (a minimal sketch of both options follows).
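For example, from an R session at the repository root (a sketch; it assumes the packages listed in `~/src/depends.R` can be installed):

```r
# Option 2: run the whole pipeline (installs dependencies, simulates data,
# fits both models, then plots and scores the forecasts)
install.packages("here")   # if not already installed
source("show_models.R")

# Option 1: run the pieces individually to inspect each step
source("src/depends.R")                      # package dependencies
source("simulate_hierarchical_epidemic.R")   # simulated data
source("run_arima.R")                        # ARIMA forecasts
source("run_gam.R")                          # GAM forecasts
```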
@@ -0,0 +1,34 @@
# configuration file for models/data

# common parameters across GAM and ARIMA
overall_parameters:
  # number of days of historic data to train models on
  fitting_window: 63
  # how far to predict into the future
  horizon: 14
  # how many historic projections to do
  n_lookbacks: 12
  # number of samples from models used to generate quantiles
  n_pi_samples: 1000
  # whether to run the GAM in parallel; should probably be FALSE if running locally
  is_parallel: TRUE

# tune these parameters
gam_parameters:
  # days per basis function of the national spline
  national_spline: 5
  # days per basis function of the sub-regional splines
  sub_regional_spline: 8
  # type of spline, usually thin plate ("tp") or cubic regression ("cr")
  spline_type: "tp"

# config for the flu wave itself
epidemic_data:
  # for simplicity, number of sub-regions == number of regions
  n_regions: 3
  n_sub_regions: 3
  # not strictly needed (it is easier to work with `t`), but it makes the data
  # feel more realistic
  start_date: "2022-09-01"
  # length of the time series
  final_t: 150
Empty file.
@@ -0,0 +1,73 @@
# Script to run ARIMA

# assumes working directory set to the directory this file sits in
source("./src/arima.R")

print("Running ARIMA")

config_path <- "./config.yaml"
config <- yaml::read_yaml(config_path)

simulated_output <- readRDS("./data_storage/epidemic_wave.rds")

# Create a time series for each aggregation level
timeseries_sub_region <- simulated_output$sub_region_data |>
  tsibble::as_tsibble(key = c("sub_region", "region"), index = date)

timeseries_nation <- simulated_output$nation_data |>
  tsibble::as_tsibble(index = date)

timeseries_region <- simulated_output$region_data |>
  tsibble::as_tsibble(key = c("region"), index = date)

## Create dfs for each lookback
sub_region_lookbacks <- list()
region_lookbacks <- list()
nation_lookbacks <- list()
prediction_start_dates <- tibble::tibble(lookback = NA, prediction_start = NA)

print("Setting up historic slices")
for (num in 2:config$overall_parameters$n_lookbacks) {
  min_date <- max(timeseries_sub_region$date) - config$overall_parameters$fitting_window - (7 * num)
  max_date <- max(timeseries_sub_region$date) - (7 * num)
  name <- paste0("lookback_", num * 7)

  sub_region_lookbacks[[name]] <- timeseries_sub_region |> dplyr::filter(date >= min_date & date <= max_date)

  region_lookbacks[[name]] <- timeseries_region |> dplyr::filter(date >= min_date & date <= max_date)

  nation_lookbacks[[name]] <- timeseries_nation |> dplyr::filter(date >= min_date & date <= max_date)

  prediction_start_dates <- prediction_start_dates |>
    dplyr::add_row(lookback = name, prediction_start = max_date + 1)
}
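
# Illustration only (not part of the pipeline above): with the defaults
# fitting_window = 63 and n_lookbacks = 12, each slice is a 63-day window
# ending 7 * num days before the end of the series. For a hypothetical series
# ending 2023-01-29 the windows would be:
#
# do.call(rbind, lapply(2:12, function(num) {
#   data.frame(
#     lookback = paste0("lookback_", num * 7),
#     min_date = as.Date("2023-01-29") - 63 - (7 * num),
#     max_date = as.Date("2023-01-29") - (7 * num)
#   )
# }))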

## Run arima for different aggregations
# runs number of geographical units x number of lookbacks, so may take time
print("Running national")
nation_arima_output <- lapply(X = nation_lookbacks,
                              FUN = run_arima_agg,
                              geography = "nation",
                              true_data = simulated_output) |>
  dplyr::bind_rows()

print("Running regional")
region_arima_output <- lapply(X = region_lookbacks,
                              FUN = run_arima_agg,
                              geography = "region",
                              true_data = simulated_output) |>
  dplyr::bind_rows()

print("Running sub-regional")
sub_region_arima_output <- lapply(X = sub_region_lookbacks,
                                  FUN = run_arima_agg,
                                  geography = "sub_region",
                                  true_data = simulated_output) |>
  dplyr::bind_rows()

saveRDS(sub_region_arima_output, "./data_storage/sub_region_arima_output.rds")
saveRDS(region_arima_output, "./data_storage/region_arima_output.rds")
saveRDS(nation_arima_output, "./data_storage/nation_arima_output.rds")
@@ -0,0 +1,120 @@
# Script to run GAM

# assumes working directory set to the directory this file sits in
source("./src/gam.R")

print("Running GAM")

config_path <- "./config.yaml"
config <- yaml::read_yaml(config_path)

simulated_output <- readRDS("./data_storage/epidemic_wave.rds")

# we are only using the sub_region data (most granular) to fit
flu <- simulated_output$sub_region_data

# Note: we will use 28 * 3 / 7 cores
max_date <- max(flu$date)
preds_storage <- data.frame()

factor2ref <- dplyr::distinct(flu, sub_region, region) |>
  dplyr::rename(factor = sub_region, factor2 = region)

# create parallel clusters
# WARNING - you will want to change the configuration of parallelisation
# dependent on your device as this may slam your processing cores
if (config$overall_parameters$is_parallel == TRUE) {
  # note: FORK clusters are only available on Unix-alikes (not Windows)
  cl <- parallel::makeCluster(ceiling(config$overall_parameters$fitting_window / 7), type = "FORK")
  doParallel::registerDoParallel(cl)
  `%runloop%` <- foreach::`%dopar%`
} else {
  `%runloop%` <- foreach::`%do%`
}

# run the big loop of GAMs, one fit per lookback dataset
# 7 days of space between each lookback period
preds_storage <- foreach::foreach(lookback_week = seq(0, config$overall_parameters$n_lookbacks * 7, 7),
                                  .combine = "bind_rows") %runloop% {
  lookback_subset <- flu |>
    dplyr::filter(
      date <= max(date) - lubridate::days(lookback_week),
      date >= (max(date) - lubridate::days(lookback_week) - config$overall_parameters$fitting_window)
    )

  preds <- run_gam_spatial(
    admissions = lookback_subset$flu_admissions,
    date = lubridate::ymd(lookback_subset$date),
    factor = lookback_subset$sub_region,
    factor2 = lookback_subset$region,
    population = lookback_subset$population_size,
    denominator = config$gam_parameters$national_spline,
    denominator_factor = config$gam_parameters$sub_regional_spline,
    spatial_nb_object = simulated_output$neighbour_list,
    ref = factor2ref,
    bs = config$gam_parameters$spline_type,
    horizon = config$overall_parameters$horizon,
    family = "nb",
    n_pi_samples = config$overall_parameters$n_pi_samples
  )
}

# end this cluster
if (config$overall_parameters$is_parallel) parallel::stopCluster(cl)

sub_region_gam_output <- preds_storage |>
  dplyr::filter(prediction == TRUE) |>
  dplyr::select(-prediction) |>
  dplyr::rename(region = factor2, sub_region = factor) |>
  dplyr::left_join(
    flu |>
      dplyr::select(date, region, sub_region, flu_admissions, population_size) |>
      dplyr::rename(true_value = flu_admissions),
    by = c("date", "region", "sub_region")
  ) |>
  dplyr::mutate(model = "GAM",
                geography = "sub_region") |>
  # convert the pi_* rate columns to long format and scale to admission counts
  tidyr::pivot_longer(dplyr::starts_with(c("pi_")), names_to = "target_type", values_to = "prediction_rate") |>
  dplyr::mutate(prediction = prediction_rate * population_size) |>
  dplyr::mutate(quantile = dplyr::case_when(
    stringr::str_detect(target_type, "fit") ~ 0.5,
    stringr::str_detect(target_type, "lower_90") ~ 0.05,
    stringr::str_detect(target_type, "upper_90") ~ 0.95,
    stringr::str_detect(target_type, "lower_95") ~ 0.025,
    stringr::str_detect(target_type, "upper_95") ~ 0.975,
    stringr::str_detect(target_type, "lower_66") ~ 0.17,
    stringr::str_detect(target_type, "upper_66") ~ 0.83,
    stringr::str_detect(target_type, "lower_50") ~ 0.25,
    stringr::str_detect(target_type, "upper_50") ~ 0.75,
    TRUE ~ NA_real_
  )) |>
  dplyr::inner_join(preds_storage |>
                      dplyr::group_by(start_date) |>
                      dplyr::summarise(prediction_start_date = max(date) - 13),
                    by = "start_date"
  ) |>
  dplyr::select(prediction_start_date, date, region, sub_region, quantile, prediction, true_value, model, geography) |>
  dplyr::filter(prediction_start_date <= max_date - config$overall_parameters$horizon)

nation_gam_output <- sub_region_gam_output |>
  dplyr::group_by(prediction_start_date, date, quantile, model) |>
  dplyr::summarise(dplyr::across(dplyr::where(is.numeric), \(x) sum(x, na.rm = TRUE))) |>
  dplyr::ungroup() |>
  dplyr::mutate(geography = "nation")

region_gam_output <- sub_region_gam_output |>
  dplyr::group_by(prediction_start_date, date, region, quantile, model) |>
  dplyr::summarise(dplyr::across(dplyr::where(is.numeric), \(x) sum(x, na.rm = TRUE))) |>
  dplyr::ungroup() |>
  dplyr::mutate(geography = "region")

saveRDS(sub_region_gam_output, "./data_storage/sub_region_gam_output.rds")
saveRDS(nation_gam_output, "./data_storage/nation_gam_output.rds")
saveRDS(region_gam_output, "./data_storage/region_gam_output.rds")
@@ -0,0 +1,110 @@
# script to run the models, save their outputs by calling the other scripts,
# and plot the projections

# `here` may need to be installed already
here::here()

set.seed(8675309)

# run the models and save their outputs
source("./src/depends.R")
source("./simulate_hierarchical_epidemic.R")
source("./run_arima.R")
source("./run_gam.R")

ggplot2::theme_set(ggplot2::theme_bw())

# load predictions
# ARIMA
nation_arima_output <- readRDS("./data_storage/nation_arima_output.rds")
region_arima_output <- readRDS("./data_storage/region_arima_output.rds")
sub_region_arima_output <- readRDS("./data_storage/sub_region_arima_output.rds")

# GAM
nation_gam_output <- readRDS("./data_storage/nation_gam_output.rds")
region_gam_output <- readRDS("./data_storage/region_gam_output.rds")
sub_region_gam_output <- readRDS("./data_storage/sub_region_gam_output.rds")

# combine together for ease
results <- dplyr::bind_rows(
  nation_gam_output,
  region_gam_output,
  sub_region_gam_output,
  nation_arima_output,
  region_arima_output,
  sub_region_arima_output
) |>
  dplyr::filter(as.numeric(quantile) %in% c(0.5, 0.95, 0.05, 0.25, 0.75))

# plot forecasts

# sub-region
results |>
  dplyr::filter(geography == "sub_region") |>
  dplyr::filter(date <= max(prediction_start_date)) |>
  dplyr::mutate(prediction_start_date = as.factor(prediction_start_date)) |>
  tidyr::pivot_wider(names_from = "quantile", values_from = "prediction") |>
  ggplot() +
  geom_point(aes(x = date, y = true_value), alpha = 0.1, size = 1) +
  geom_line(aes(x = date, y = `0.5`, group = prediction_start_date), linetype = 2, alpha = 0.5) +
  geom_ribbon(aes(x = date, ymax = `0.95`, ymin = `0.05`, group = prediction_start_date, fill = prediction_start_date), alpha = 0.2) +
  geom_ribbon(aes(x = date, ymax = `0.75`, ymin = `0.25`, group = prediction_start_date, fill = prediction_start_date), alpha = 0.5) +
  facet_grid(sub_region ~ model, scales = "free") +
  ylab("influenza admissions")

# national
# as we are aggregating uncertainty naively the prediction intervals are very wide
results |>
  dplyr::filter(geography == "nation") |>
  dplyr::filter(date <= max(prediction_start_date)) |>
  dplyr::mutate(prediction_start_date = as.factor(prediction_start_date)) |>
  tidyr::pivot_wider(names_from = "quantile", values_from = "prediction") |>
  ggplot() +
  geom_point(aes(x = date, y = true_value), alpha = 0.1, size = 1) +
  geom_line(aes(x = date, y = `0.5`, group = prediction_start_date), linetype = 2, alpha = 0.5) +
  geom_ribbon(aes(x = date, ymax = `0.95`, ymin = `0.05`, group = prediction_start_date, fill = prediction_start_date), alpha = 0.2) +
  geom_ribbon(aes(x = date, ymax = `0.75`, ymin = `0.25`, group = prediction_start_date, fill = prediction_start_date), alpha = 0.5) +
  facet_wrap(~model) +
  ylab("influenza admissions")

# region
results |>
  dplyr::filter(geography == "region") |>
  dplyr::filter(date <= max(prediction_start_date)) |>
  dplyr::mutate(prediction_start_date = as.factor(prediction_start_date)) |>
  tidyr::pivot_wider(names_from = "quantile", values_from = "prediction") |>
  ggplot() +
  geom_point(aes(x = date, y = true_value), alpha = 0.1, size = 1) +
  geom_line(aes(x = date, y = `0.5`, group = prediction_start_date), linetype = 2, alpha = 0.5) +
  geom_ribbon(aes(x = date, ymax = `0.95`, ymin = `0.05`, group = prediction_start_date, fill = prediction_start_date), alpha = 0.1) +
  geom_ribbon(aes(x = date, ymax = `0.75`, ymin = `0.25`, group = prediction_start_date, fill = prediction_start_date), alpha = 0.3) +
  facet_grid(region ~ model, scales = "free") +
  ylab("influenza admissions")

# score forecasts
# example of scoring the sub-region level predictions and showing the result
sub_region_score <- results |>
  dplyr::filter(geography == "sub_region") |>
  dplyr::filter(date < max(prediction_start_date)) |>
  scoringutils::score() |>
  scoringutils::add_coverage(by = c("model", "prediction_start_date"), ranges = c(50, 90)) |>
  scoringutils::summarise_scores(
    by = c("model", "prediction_start_date"),
    na.rm = TRUE
  ) |>
  scoringutils::summarise_scores(fun = signif, digits = 3)

# explore scores over time
sub_region_score |>
  ggplot() +
  geom_line(aes(x = prediction_start_date, y = interval_score, group = model, color = model))

# explore coverage over time
sub_region_score |>
  ggplot() +
  geom_line(aes(x = prediction_start_date, y = coverage_90, group = model, color = model)) +
  geom_line(aes(x = prediction_start_date, y = coverage_50, group = model, color = model), linetype = 2) +
  geom_hline(aes(yintercept = 0.9), alpha = 0.7) +
  geom_hline(aes(yintercept = 0.5), linetype = 2, alpha = 0.7) +
  ylim(c(0, NA)) +
  ylab("coverage")