Commit

Added content of repository
jonathonmellor committed Dec 21, 2023
1 parent 7284403 commit 8023c79
Showing 11 changed files with 800 additions and 2 deletions.
3 changes: 3 additions & 0 deletions .gitignore
@@ -0,0 +1,3 @@
*.rds
*.csv
./data/*
44 changes: 42 additions & 2 deletions README.md
@@ -1,2 +1,42 @@
# flu-forecast-2022-23
Code for the paper: Forecasting influenza hospital admissions within English sub-regions using hierarchical generalised additive models
# README

Supporting code for the paper "Forecasting influenza hospital admissions within English sub-regions using hierarchical generalised additive models", published in Nature Communications Medicine.

## Data

The data used for this study were live operational data from NHS England ([definition given here](https://www.england.nhs.uk/long-read/process-and-definitions-for-the-daily-situation-report-web-form/)). This influenza data is not in the public domain, so we have provided simulated data in this repository to demonstrate the models.

The data are generated assuming a national epidemic wave, represented by local units that follow the same underlying trend plus local perturbations.
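
A minimal sketch of that idea is below. It is purely illustrative: the actual generator lives in `~/simulate_hierarchical_epidemic.R`, and the object and parameter names here are made up for the example rather than taken from the repository.

```r
# Illustrative only: a shared national wave plus per-unit perturbations.
# The real simulation (simulate_hierarchical_epidemic.R) differs in detail.
set.seed(42)
t <- 1:150                                            # days, matching final_t in config.yaml
national_rate <- 50 * exp(-(t - 75)^2 / (2 * 20^2))   # smooth national epidemic curve
sub_regions <- paste0("sub_region_", 1:9)             # 3 regions x 3 sub-regions by default

simulated <- do.call(rbind, lapply(sub_regions, function(s) {
  perturbation <- exp(rnorm(length(t), mean = 0, sd = 0.2))  # unit-level noise around the trend
  data.frame(
    t = t,
    sub_region = s,
    flu_admissions = rpois(length(t), lambda = national_rate * perturbation)
  )
}))

head(simulated)
```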


## Structure

Running this code assumes the working directory is the repository root (i.e. where `show_models.R` sits).

The main entry point for testing and visualising the models is `~/show_models.R`.

This script will:

1. install the package dependencies from `~/src/depends.R` (you may want to install `here` first to set working directories)
2. generate the simulated epidemic data from `~/simulate_hierarchical_epidemic.R`
3. source and run the ARIMA model from `~/run_arima.R`
4. source and run the GAM model from `~/run_gam.R`
5. write out the different data to `~/data_storage/`
6. read the data back in, plot it and score the models (a sketch of inspecting one of the written outputs follows this list)
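
As a rough guide to what gets written out, the snippet below reads one of the saved forecast files back in. Treat it as a sketch: the column names are taken from how the output is assembled in `~/run_gam.R`, not a guaranteed schema.

```r
# Sketch: inspect one of the forecast outputs written by the pipeline.
# The ARIMA outputs follow the same long format.
sub_region_gam_output <- readRDS("./data_storage/sub_region_gam_output.rds")

# expected columns (from run_gam.R): prediction_start_date, date, region,
# sub_region, quantile, prediction, true_value, model, geography
str(sub_region_gam_output)

# e.g. pull out the median forecasts
sub_region_gam_output |>
  dplyr::filter(quantile == 0.5) |>
  head()
```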

The configuration for data generation, overall modelling, and GAM parameters can be adjusted in `~/config.yaml`.

Note that the models have not been pre-tuned.

- The bounds of available ranges for the ARIMA can be selected within `~/src/arima.R`, though the model will select sensible values itself.
- The days per basis function for the GAM can be updated in the `~/config.yaml` file.

The GAM model can be run either in parallel or sequentially; the option is set in `~/config.yaml`. Each fit takes roughly 1-5 minutes, so expect a long run time if running sequentially. In the default setting it requires ~12 cores to run in parallel, roughly one core per lookback (`n_lookbacks`).
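
If you are unsure whether your machine can support the parallel run, a quick check along these lines may help before flipping `is_parallel`. This is a sketch, not part of the repository scripts; the requested core count mirrors the cluster size used in `~/run_gam.R`.

```r
# Sketch: compare available cores with what the parallel GAM run will request.
config <- yaml::read_yaml("./config.yaml")
cores_requested <- ceiling(config$overall_parameters$fitting_window / 7)  # as in run_gam.R
cores_available <- parallel::detectCores()

if (cores_available <= cores_requested) {
  message(
    "Consider setting is_parallel: FALSE in config.yaml (",
    cores_available, " cores available, ", cores_requested, " requested)."
  )
}
```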


## Running

To test the code and different components you can:

1. run `~/simulate_hierarchical_epidemic.R`, `~/run_arima.R` and `~/run_gam.R` separately to inspect each step (see the sketch after this list), or
2. bring it all together with `~/show_models.R`, which also contains the plotting and scoring code
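
For example, option 1 looks roughly like the following from an R session at the repository root (a minimal sketch; `show_models.R` wraps the same steps and adds plotting and scoring):

```r
# Minimal run sketch, assuming the working directory is the repository root.
# install.packages("here")  # if not already installed
source("./src/depends.R")                     # install/load package dependencies
source("./simulate_hierarchical_epidemic.R")  # write simulated data to ./data_storage/
source("./run_arima.R")                       # fit and save the ARIMA forecasts
source("./run_gam.R")                         # fit and save the GAM forecasts

# or run everything, including plotting and scoring, in one go:
# source("./show_models.R")
```
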
34 changes: 34 additions & 0 deletions config.yaml
@@ -0,0 +1,34 @@
# configuration file for models/data

# common parameters across GAM and ARIMA
overall_parameters:
  # number of days of historic data to train models on
  fitting_window: 63
  # how far to predict into the future
  horizon: 14
  # how many historic projections to do
  n_lookbacks: 12
  # number of samples from models to generate quantiles
  n_pi_samples: 1000
  # whether to run GAM in parallel, should probably be FALSE if running locally
  is_parallel: TRUE

# tune these parameters
gam_parameters:
  # days per basis function of national spline
  national_spline: 5
  # days per basis function of sub-regional splines
  sub_regional_spline: 8
  # type of spline, thin plate or cubic regression usually
  spline_type: "tp"

# config for flu wave itself
epidemic_data:
  # for simplicity number of sub regions == number of regions
  n_regions: 3
  n_sub_regions: 3
  # we don't actually need this as it's easier to work with `t`, but makes it feel
  # more realistic
  start_date: "2022-09-01"
  # length of the time series
  final_t: 150
Empty file added data_storage/.gitkeep
Empty file.
73 changes: 73 additions & 0 deletions run_arima.R
@@ -0,0 +1,73 @@
# Script to run ARIMA

# assumes working directory set to the directory this file sits in
source("./src/arima.R")

print("Running ARIMA")


config_path <- "./config.yaml"
config <- yaml::read_yaml(config_path)

simulated_output <- readRDS("./data_storage/epidemic_wave.rds")

# Create a time series
timeseries_sub_region <- simulated_output$sub_region_data |>
  tsibble::as_tsibble(key = c("sub_region", "region"), index = date)

timeseries_nation <- simulated_output$nation_data |>
  tsibble::as_tsibble(index = date)

timeseries_region <- simulated_output$region_data |>
  tsibble::as_tsibble(key = c("region"), index = date)


## Create dfs for each lookback
sub_region_lookbacks <- list()
region_lookbacks <- list()
nation_lookbacks <- list()
prediction_start_dates <- tibble::tibble(lookback = NA, prediction_start = NA)

print("Setting up historic slices")
for (num in 2:config$overall_parameters$n_lookbacks) {
  min_date <- max(timeseries_sub_region$date) - config$overall_parameters$fitting_window - (7 * num)
  max_date <- max(timeseries_sub_region$date) - (7 * num)
  name <- paste0("lookback_", num * 7)

  sub_region_lookbacks[[name]] <- timeseries_sub_region |> dplyr::filter(date >= min_date & date <= max_date)

  region_lookbacks[[name]] <- timeseries_region |> dplyr::filter(date >= min_date & date <= max_date)

  nation_lookbacks[[name]] <- timeseries_nation |> dplyr::filter(date >= min_date & date <= max_date)

  prediction_start_dates <- prediction_start_dates |>
    dplyr::add_row(lookback = name, prediction_start = max_date + 1)
}



## Run arima for different aggregations
# runs number of geographical units x number of lookbacks, so may take time
print("Running national")
nation_arima_output <- lapply(X = nation_lookbacks,
                              FUN = run_arima_agg,
                              geography = "nation",
                              true_data = simulated_output) |>
  dplyr::bind_rows()

print("Running regional")
region_arima_output <- lapply(X = region_lookbacks,
                              FUN = run_arima_agg,
                              geography = "region",
                              true_data = simulated_output) |>
  dplyr::bind_rows()

print("Running sub-regional")
sub_region_arima_output <- lapply(X = sub_region_lookbacks,
                                  FUN = run_arima_agg,
                                  geography = "sub_region",
                                  true_data = simulated_output) |>
  dplyr::bind_rows()

saveRDS(sub_region_arima_output, "./data_storage/sub_region_arima_output.rds")
saveRDS(region_arima_output, "./data_storage/region_arima_output.rds")
saveRDS(nation_arima_output, "./data_storage/nation_arima_output.rds")
120 changes: 120 additions & 0 deletions run_gam.R
@@ -0,0 +1,120 @@
# Script to run GAM

# assumes working directory set to the directory this file sits in
source("./src/gam.R")

print("Running GAM")



config_path <- "./config.yaml"
config <- yaml::read_yaml(config_path)

simulated_output <- readRDS("./data_storage/epidemic_wave.rds")

# we are only using the sub_region data (most granular) to fit
flu <- simulated_output$sub_region_data


# Note: when run in parallel, the cluster uses ceiling(fitting_window / 7) cores (9 with the default config)
max_date <- max(flu$date)
preds_storage <- data.frame()

factor2ref <- dplyr::distinct(flu, sub_region, region) |>
dplyr::rename(factor = sub_region, factor2 = region)

# create parallel clusters
# WARNING - you will want to change the configuration of parallelisation
# dependent on your device as this may slam your processing cores
if (config$overall_parameters$is_parallel == TRUE) {
  cl <- parallel::makeCluster(ceiling(config$overall_parameters$fitting_window / 7), type = "FORK")
  doParallel::registerDoParallel(cl)
  `%runloop%` <- foreach::`%dopar%`
} else {
  `%runloop%` <- foreach::`%do%`
}


# run the big loop of GAMs for each lookback data
# 7 days space between each lookback period
preds_storage <- foreach::foreach(
  lookback_week = seq(0, config$overall_parameters$n_lookbacks * 7, 7),
  .combine = "bind_rows"
) %runloop% {
  flu |>
    filter(
      date <= max(date) - lubridate::days(lookback_week),
      date >= (max(date) - lubridate::days(lookback_week) - config$overall_parameters$fitting_window)
    ) -> lookback_subset

  preds <- run_gam_spatial(
    admissions = lookback_subset$flu_admissions,
    date = lubridate::ymd(lookback_subset$date),
    factor = lookback_subset$sub_region,
    factor2 = lookback_subset$region,
    population = lookback_subset$population_size,
    denominator = config$gam_parameters$national_spline,
    denominator_factor = config$gam_parameters$sub_regional_spline,
    spatial_nb_object = simulated_output$neighbour_list,
    ref = factor2ref,
    bs = config$gam_parameters$spline_type,
    horizon = config$overall_parameters$horizon,
    family = "nb",
    n_pi_samples = config$overall_parameters$n_pi_samples
  )
}

# end this cluster
if (config$overall_parameters$is_parallel) stopCluster(cl)


sub_region_gam_output <- preds_storage |>
  dplyr::filter(prediction == TRUE) |>
  dplyr::select(-prediction) |>
  dplyr::rename(region = factor2, sub_region = factor) |>
  dplyr::left_join(
    flu |>
      select(date, region, sub_region, flu_admissions, population_size) |>
      rename(true_value = flu_admissions),
    by = c("date", "region", "sub_region")
  ) |>
  dplyr::mutate(model = "GAM",
                geography = "sub_region") |>
  tidyr::pivot_longer(dplyr::starts_with(c("pi_")), names_to = "target_type", values_to = "prediction_rate") |>
  dplyr::mutate(prediction = prediction_rate * population_size) |>
  dplyr::mutate(quantile = dplyr::case_when(
    stringr::str_detect(target_type, "fit") ~ 0.5,
    stringr::str_detect(target_type, "lower_90") ~ 0.05,
    stringr::str_detect(target_type, "upper_90") ~ 0.95,
    stringr::str_detect(target_type, "lower_95") ~ 0.025,
    stringr::str_detect(target_type, "upper_95") ~ 0.975,
    stringr::str_detect(target_type, "lower_66") ~ 0.17,
    stringr::str_detect(target_type, "upper_66") ~ 0.83,
    stringr::str_detect(target_type, "lower_50") ~ 0.25,
    stringr::str_detect(target_type, "upper_50") ~ 0.75,
    TRUE ~ NA_real_
  )) |>
  dplyr::inner_join(
    preds_storage |>
      dplyr::group_by(start_date) |>
      dplyr::summarise(prediction_start_date = max(date) - 13),
    by = "start_date"
  ) |>
  dplyr::select(prediction_start_date, date, region, sub_region, quantile, prediction, true_value, model, geography) |>
  dplyr::filter(prediction_start_date <= max_date - config$overall_parameters$horizon)


nation_gam_output <- sub_region_gam_output |>
  dplyr::group_by(prediction_start_date, date, quantile, model) |>
  dplyr::summarise(dplyr::across(dplyr::where(is.numeric), \(x) sum(x, na.rm = T))) |>
  dplyr::ungroup() |>
  dplyr::mutate(geography = "nation")


region_gam_output <- sub_region_gam_output |>
  dplyr::group_by(prediction_start_date, date, region, quantile, model) |>
  dplyr::summarise(dplyr::across(dplyr::where(is.numeric), \(x) sum(x, na.rm = T))) |>
  dplyr::ungroup() |>
  dplyr::mutate(geography = "region")

saveRDS(sub_region_gam_output, "./data_storage/sub_region_gam_output.rds")
saveRDS(nation_gam_output, "./data_storage/nation_gam_output.rds")
saveRDS(region_gam_output, "./data_storage/region_gam_output.rds")
110 changes: 110 additions & 0 deletions show_models.R
@@ -0,0 +1,110 @@
# script to run the models, save their outputs by calling the other scripts,
# and plot the projections

# may need to have `here` installed already
here::here()

set.seed(8675309)

# run the models and save their outputs
source("./src/depends.R")
source("./simulate_hierarchical_epidemic.R")
source("./run_arima.R")
source("./run_gam.R")

ggplot2::theme_set(ggplot2::theme_bw())


# load predictions
# ARIMA
nation_arima_output <- readRDS("./data_storage/nation_arima_output.rds")
region_arima_output <- readRDS("./data_storage/region_arima_output.rds")
sub_region_arima_output <- readRDS("./data_storage/sub_region_arima_output.rds")

# GAM
nation_gam_output <- readRDS("./data_storage/nation_gam_output.rds")
region_gam_output <- readRDS("./data_storage/region_gam_output.rds")
sub_region_gam_output <- readRDS("./data_storage/sub_region_gam_output.rds")

# combine together for ease
results <- dplyr::bind_rows(
  nation_gam_output,
  region_gam_output,
  sub_region_gam_output,
  nation_arima_output,
  region_arima_output,
  sub_region_arima_output
) |>
  dplyr::filter(as.numeric(quantile) %in% c(0.5, 0.95, 0.05, 0.25, 0.75))


# plot forecasts

# sub-region
results |>
  dplyr::filter(geography == "sub_region") |>
  dplyr::filter(date <= max(prediction_start_date)) |>
  dplyr::mutate(prediction_start_date = as.factor(prediction_start_date)) |>
  tidyr::pivot_wider(names_from = "quantile", values_from = "prediction") |>
  ggplot() +
  geom_point(aes(x = date, y = true_value), alpha = 0.1, size = 1) +
  geom_line(aes(x = date, y = `0.5`, group = prediction_start_date), linetype = 2, alpha = 0.5) +
  geom_ribbon(aes(x = date, ymax = `0.95`, ymin = `0.05`, group = prediction_start_date, fill = prediction_start_date), alpha = 0.2) +
  geom_ribbon(aes(x = date, ymax = `0.75`, ymin = `0.25`, group = prediction_start_date, fill = prediction_start_date), alpha = 0.5) +
  facet_grid(sub_region ~ model, scales = "free") +
  ylab("influenza admissions")

# national
# as we aggregate uncertainty naively, the prediction intervals are very wide
results |>
  dplyr::filter(geography == "nation") |>
  dplyr::filter(date <= max(prediction_start_date)) |>
  dplyr::mutate(prediction_start_date = as.factor(prediction_start_date)) |>
  tidyr::pivot_wider(names_from = "quantile", values_from = "prediction") |>
  ggplot() +
  geom_point(aes(x = date, y = true_value), alpha = 0.1, size = 1) +
  geom_line(aes(x = date, y = `0.5`, group = prediction_start_date), linetype = 2, alpha = 0.5) +
  geom_ribbon(aes(x = date, ymax = `0.95`, ymin = `0.05`, group = prediction_start_date, fill = prediction_start_date), alpha = 0.2) +
  geom_ribbon(aes(x = date, ymax = `0.75`, ymin = `0.25`, group = prediction_start_date, fill = prediction_start_date), alpha = 0.5) +
  facet_wrap(~model) +
  ylab("influenza admissions")

# region
results |>
  dplyr::filter(geography == "region") |>
  dplyr::filter(date <= max(prediction_start_date)) |>
  dplyr::mutate(prediction_start_date = as.factor(prediction_start_date)) |>
  tidyr::pivot_wider(names_from = "quantile", values_from = "prediction") |>
  ggplot() +
  geom_point(aes(x = date, y = true_value), alpha = 0.1, size = 1) +
  geom_line(aes(x = date, y = `0.5`, group = prediction_start_date), linetype = 2, alpha = 0.5) +
  geom_ribbon(aes(x = date, ymax = `0.95`, ymin = `0.05`, group = prediction_start_date, fill = prediction_start_date), alpha = 0.1) +
  geom_ribbon(aes(x = date, ymax = `0.75`, ymin = `0.25`, group = prediction_start_date, fill = prediction_start_date), alpha = 0.3) +
  facet_grid(region ~ model, scales = "free") +
  ylab("influenza admissions")


# score forecasts
# example of scoring at sub_region level prediction and showing result
sub_region_score <- results |>
  dplyr::filter(geography == "sub_region") |>
  dplyr::filter(date < max(prediction_start_date)) |>
  scoringutils::score() |>
  scoringutils::add_coverage(by = c("model", "prediction_start_date"), ranges = c(50, 90)) |>
  scoringutils::summarise_scores(
    by = c("model", "prediction_start_date"),
    na.rm = TRUE
  ) |>
  scoringutils::summarise_scores(fun = signif, digits = 3)

# explore scores over time
sub_region_score |>
  ggplot() +
  geom_line(aes(x = prediction_start_date, y = interval_score, group = model, color = model))

# explore coverage over time
sub_region_score |>
  ggplot() +
  geom_line(aes(x = prediction_start_date, y = coverage_90, group = model, color = model)) +
  geom_line(aes(x = prediction_start_date, y = coverage_50, group = model, color = model), linetype = 2) +
  geom_hline(aes(yintercept = 0.9), alpha = 0.7) +
  geom_hline(aes(yintercept = 0.5), linetype = 2, alpha = 0.7) +
  ylim(c(0, NA)) +
  ylab("coverage")