break up create_clean_reported_cases() #884

Merged
merged 20 commits into from
Dec 10, 2024
2 changes: 2 additions & 0 deletions NAMESPACE
@@ -33,6 +33,7 @@ export(LogNormal)
export(NonParametric)
export(Normal)
export(R_to_growth)
export(add_breakpoints)
export(adjust_infection_to_report)
export(apply_tolerance)
export(backcalc_opts)
@@ -64,6 +65,7 @@ export(extract_inits)
export(extract_samples)
export(extract_stan_param)
export(fill_missing)
export(filter_leading_zeros)
export(fix_dist)
export(fix_parameters)
export(forecast_infections)
4 changes: 4 additions & 0 deletions NEWS.md
@@ -19,6 +19,10 @@
- A bug was fixed where an internal function for applying a default cdf cutoff failed due to a vector length issue. By @jamesmbaazam in #858 and reviewed by @sbfnk.
- All parameters have been changed to the new parameter interface. By @sbfnk in #871 and reviewed by @seabbs.

## Package changes

- The internal function `create_clean_reported_cases()` has been broken up into several functions, with the relevant ones, `filter_leading_zeros()`, `add_breakpoints()` and `apply_zero_threshold()`, exposed to the user. By @sbfnk in #884 and reviewed by @seabbs.

## Documentation

- Brought the docs on `alpha_sd` up to date with the code change from prior PR #853. By @zsusswein in #862 and reviewed by @jamesmbaazam.
62 changes: 9 additions & 53 deletions R/create.R
@@ -1,5 +1,5 @@
#' Create Clean Reported Cases
#' @description `r lifecycle::badge("stable")`
#' @description `r lifecycle::badge("deprecated")`
#' Filters leading zeros, completes dates, and applies an optional threshold at
#' which point 0 cases are replaced with a user supplied value (defaults to
#' `NA`).
@@ -12,16 +12,12 @@
#' number of cases based on the 7-day average. If the average is above this
#' threshold then the zero is replaced using `fill`.
#'
#' @param fill Numeric, defaults to NA. Value to use to replace NA values or
#' zeroes that are flagged because the 7-day average is above the
#' `zero_threshold`. If the default NA is used then dates with NA values or with
#' 7-day averages above the `zero_threshold` will be skipped in model fitting.
#' If this is set to 0 then the only effect is to replace NA values with 0.
#' @param fill Deprecated; zero dates with 7-day averages above the
#' `zero_threshold` will be skipped in model fitting.
#' @param add_breakpoints Logical, defaults to TRUE. Should a breakpoint column
#' be added to the data frame if it does not exist.
#'
#' @inheritParams estimate_infections
#' @importFrom data.table copy merge.data.table setorder setDT frollsum
#' @return A cleaned data frame of reported cases
#' @keywords internal
#' @examples
@@ -33,55 +29,15 @@ create_clean_reported_cases <- function(data, horizon = 0,
zero_threshold = Inf,
fill = NA_integer_,
add_breakpoints = TRUE) {
reported_cases <- data.table::setDT(data)
reported_cases_grid <- data.table::copy(reported_cases)[,
.(date = seq(min(date), max(date) + horizon, by = "days"))
]

reported_cases <- data.table::merge.data.table(
reported_cases, reported_cases_grid,
by = "date", all.y = TRUE
)

if (is.null(reported_cases$breakpoint) && add_breakpoints) {
reported_cases$breakpoint <- 0
reported_cases <- add_horizon(data, horizon = horizon)
if (add_breakpoints) {
reported_cases <- add_breakpoints(reported_cases)
}
if (!is.null(reported_cases$breakpoint)) {
reported_cases[is.na(breakpoint), breakpoint := 0]
}
reported_cases <- data.table::setorder(reported_cases, date)
## Filter out 0 reported cases from the beginning of the data
if (filter_leading_zeros) {
reported_cases <- reported_cases[order(date)][
date >= min(date[confirm[!is.na(confirm)] > 0])
]
}
# Calculate `average_7_day` which for rows with `confirm == 0`
# (the only instance where this is being used) equates to the 7-day
# right-aligned moving average at the previous data point.
reported_cases <-
reported_cases[
,
`:=`(average_7_day = (
data.table::frollsum(confirm, n = 8, na.rm = TRUE)
) / 7
)
]
# Check case counts preceding zero case counts and set to 7 day average if
# average over last 7 days is greater than a threshold
if (!is.infinite(zero_threshold)) {
reported_cases <- reported_cases[
confirm == 0 & average_7_day > zero_threshold,
confirm := NA_integer_
]
}
reported_cases[is.na(confirm), confirm := fill]
reported_cases[, "average_7_day" := NULL]
## set accumulate to FALSE in added rows
if ("accumulate" %in% colnames(reported_cases)) {
reported_cases[is.na(accumulate), accumulate := FALSE]
reported_cases <- filter_leading_zeros(reported_cases)
}
return(reported_cases)
reported_cases <- apply_zero_threshold(reported_cases, zero_threshold)
return(reported_cases[])
}

#' Create complete cases
27 changes: 27 additions & 0 deletions R/estimate_infections.R
@@ -140,6 +140,20 @@ estimate_infections <- function(data,
"estimate_infections(data)"
)
}
if (!missing(filter_leading_zeros)) {
lifecycle::deprecate_warn(
"1.7.0",
"estimate_infections(filter_leading_zeros)",
"filter_leading_zeros()"
)
}
if (!missing(zero_threshold)) {
lifecycle::deprecate_warn(
"1.7.0",
"estimate_infections(zero_threshold)",
"apply_zero_threshold()"
)
}
# Validate inputs
check_reports_valid(data, model = "estimate_infections")
assert_class(generation_time, "generation_time_opts")
@@ -184,6 +198,19 @@ estimate_infections <- function(data,
)
# Fill missing dates
reported_cases <- default_fill_missing_obs(data, obs, "confirm")
# Check for initial zeros to warn about the deprecated zero-filtering functionality
if (filter_leading_zeros &&
!is.na(reported_cases[date == min(date), "confirm"]) &&
reported_cases[date == min(date), "confirm"] == 0) {
cli_warn(c(
"!" = "Filtering initial zero observations in the data. This
functionality will be removed in future versions of EpiNow2. In order
to retain the default behaviour and filter initial zero observations
use the {.fn filter_leading_zeros()} function on the data before
calling {.fn estimate_infections()}."
))
}

# Create clean and complete cases
reported_cases <- create_clean_reported_cases(
reported_cases, horizon,
24 changes: 24 additions & 0 deletions R/estimate_secondary.R
@@ -172,6 +172,20 @@ estimate_secondary <- function(data,
"estimate_secondary(data)"
)
}
if (!missing(filter_leading_zeros)) {
lifecycle::deprecate_warn(
"1.7.0",
"estimate_secondary(filter_leading_zeros)",
"filter_leading_zeros()"
)
}
if (!missing(zero_threshold)) {
lifecycle::deprecate_warn(
"1.7.0",
"estimate_secondary(zero_threshold)",
"apply_zero_threshold()"
)
}
# Validate the inputs
check_reports_valid(data, model = "estimate_secondary")
assert_class(secondary, "secondary_opts")
@@ -200,6 +214,16 @@ estimate_secondary <- function(data,

secondary_reports_dirty <-
reports[, list(date, confirm = secondary, accumulate)]
if (filter_leading_zeros &&
!is.na(secondary_reports_dirty[date == min(date), "confirm"]) &&
secondary_reports_dirty[date == min(date), "confirm"] == 0) {
cli_warn(c(
"!" = "Filtering initial zero observations in the data. This
functionality will be removed in future versions of EpiNow2. In order
to filter initial zero observations use the {.fn filter_leading_zeros()}
function on the data before calling {.fn estimate_secondary()}."
))
}
secondary_reports <- create_clean_reported_cases(
secondary_reports_dirty,
filter_leading_zeros = filter_leading_zeros,
153 changes: 152 additions & 1 deletion R/preprocessing.R
@@ -42,7 +42,7 @@
##' using a data set that has multiple columns of which one of them
##' corresponds to observations that are to be processed here.
##' @param by Character vector. Name(s) of any additional column(s) where
##' missing data should be processed separately for each value in the column.
##' data processing should be done separately for each value in the column.
##' This is useful when using data representing e.g. multiple geographies. If
##' NULL (default) no such grouping is done.
##' @return a data.table with an `accumulate` column that indicates whether
@@ -177,3 +177,154 @@ default_fill_missing_obs <- function(data, obs, obs_column) {
}
return(data)
}

##' Add missing values for future dates
##'
##' @param accumulate The number of days to accumulate when generating posterior
##' prediction, e.g. 7 for weekly accumulated forecasts.
##' @inheritParams add_horizon
##' @inheritParams estimate_infections
##' @importFrom data.table copy merge.data.table setDT
##' @return A data.table with missing values for future dates
##' @keywords internal
add_horizon <- function(data, horizon, accumulate = 1L,
obs_column = "confirm", by = NULL) {
assert_data_frame(data)
assert_character(obs_column)
assert_character(by, null.ok = TRUE)
assert_names(
colnames(data),
must.include = c("date", by, obs_column)
)
assert_integerish(horizon, lower = 0)
assert_integerish(accumulate, lower = 1)
assert_date(data$date, any.missing = FALSE)

reported_cases <- data.table::setDT(data)
if (horizon > 0) {
reported_cases_future <- data.table::copy(reported_cases)[,
.(date = seq(max(date) + 1, max(date) + horizon, by = "days")),
by = by
]
## if we accumulate add the column
if (accumulate > 1 || "accumulate" %in% colnames(data)) {
reported_cases_future[, accumulate := TRUE]
## set accumulation to FALSE where appropriate
if (horizon >= accumulate) {
reported_cases_future[
as.integer(date - min(date) - 1) %% accumulate == 0,
accumulate := FALSE
]
}
}
## fill any missing columns
reported_cases <- rbind(
reported_cases, reported_cases_future,
fill = TRUE
)
}
return(reported_cases[])
}

##' Add breakpoints to certain dates in a data set.
##'
##' @param dates A vector of dates to use as breakpoints.
##' @inheritParams estimate_infections
##' @return A data.table with `breakpoint` set to 1 on each of the specified
##' dates.
##' @export
##' @importFrom data.table setDT
##' @examples
##' reported_cases <- add_breakpoints(example_confirmed, as.Date("2020-03-26"))
add_breakpoints <- function(data, dates = as.Date(character(0))) {
assert_data_frame(data)
assert_names(colnames(data), must.include = "date")
assert_date(dates)
assert_date(data$date, any.missing = FALSE)
reported_cases <- data.table::setDT(data)
if (is.null(reported_cases$breakpoint)) {
reported_cases$breakpoint <- 0
}
missing_dates <- setdiff(dates, data$date)
if (length(missing_dates) > 0) {
cli_abort("Breakpoint date{?s} not found in data: {.var {missing_dates}}")
}
reported_cases[date %in% dates, breakpoint := 1]
reported_cases[is.na(breakpoint), breakpoint := 0]
return(reported_cases)
}
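
The breakpoint logic above can be illustrated without the package. The following is a minimal base R sketch, not the package code: `flag_breakpoints()` is a hypothetical helper name, it works on a plain data frame rather than a data.table, and it uses `stop()` where the diff uses `cli_abort()`.

```r
# Base R sketch of the breakpoint-flagging idea (hypothetical helper,
# not part of EpiNow2).
flag_breakpoints <- function(data, dates = as.Date(character(0))) {
  stopifnot(is.data.frame(data), "date" %in% names(data))
  # Create the column if it is absent, defaulting to "no breakpoint"
  if (is.null(data$breakpoint)) {
    data$breakpoint <- 0
  }
  # Refuse breakpoint dates that do not occur in the data
  missing_dates <- dates[!dates %in% data$date]
  if (length(missing_dates) > 0) {
    stop(
      "Breakpoint date(s) not found in data: ",
      paste(format(missing_dates), collapse = ", ")
    )
  }
  # Flag the requested dates and make sure no NA flags remain
  data$breakpoint[data$date %in% dates] <- 1
  data$breakpoint[is.na(data$breakpoint)] <- 0
  data
}

# Made-up example data
cases <- data.frame(
  date = as.Date("2020-03-20") + 0:9,
  confirm = c(5, 8, 13, 21, 30, 41, 55, 60, 72, 80)
)
flagged <- flag_breakpoints(cases, as.Date("2020-03-26"))
```

In this sketch only the row dated 2020-03-26 ends up with `breakpoint` equal to 1, and a date outside the data raises an error instead of being silently ignored.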

##' Filter leading zeros from a data set.
##'
##' @inheritParams estimate_infections
##' @inheritParams fill_missing
##' @return A data.table with leading zeros removed.
##' @export
##' @importFrom data.table setDT
##' @examples
##' cases <- data.frame(
##' date = as.Date("2020-01-01") + 0:10,
##' confirm = c(0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9)
##' )
##' filter_leading_zeros(cases)
filter_leading_zeros <- function(data, obs_column = "confirm", by = NULL) {
assert_data_frame(data)
assert_character(obs_column)
assert_character(by, null.ok = TRUE)
assert_names(
colnames(data),
must.include = c("date", by, obs_column)
)
reported_cases <- data.table::setDT(data)
reported_cases <- reported_cases[order(date)][
date >= min(date[!is.na(get(obs_column)) & get(obs_column) > 0])
]
return(reported_cases[])
}
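
As a rough illustration of the filtering step, here is a base R sketch under stated assumptions. `drop_leading_zeros()` is a hypothetical name, it ignores the `by` argument, and returning an empty data frame when no positive observation exists is an assumption of this sketch rather than documented package behaviour.

```r
# Base R sketch of dropping rows before the first positive observation
# (hypothetical helper, not part of EpiNow2).
drop_leading_zeros <- function(data, obs_column = "confirm") {
  stopifnot(
    is.data.frame(data),
    all(c("date", obs_column) %in% names(data))
  )
  data <- data[order(data$date), , drop = FALSE]
  obs <- data[[obs_column]]
  # Treat NA as "not positive" so the index stays aligned with the dates
  positive <- !is.na(obs) & obs > 0
  if (!any(positive)) {
    # Assumption: with no positive counts there is nothing to keep
    return(data[0, , drop = FALSE])
  }
  first_positive <- min(data$date[positive])
  data[data$date >= first_positive, , drop = FALSE]
}

# Same data as in the roxygen example above
cases <- data.frame(
  date = as.Date("2020-01-01") + 0:10,
  confirm = c(0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9)
)
filtered <- drop_leading_zeros(cases)
```

Here the two leading zero rows are dropped, so `filtered` starts on 2020-01-03. Zeros that occur after the first positive count are kept.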

##' Converts zero case counts to NA (missing) if the 7-day average is above a
##' threshold.
##'
##' This function aims to detect spurious zeroes by comparing the 7-day average
##' of the case counts to a threshold. If the 7-day average is above the
##' threshold, the zero case count is replaced with NA.
##'
##' @param threshold Numeric, defaults to Inf. Indicates if detected zero cases
##' are meaningful by using a threshold number of cases based on the 7-day
##' average. If the average is above this threshold at the time of a zero
##' observation count then the zero is replaced with a missing (`NA`) count
##' and thus ignored in the likelihood.
##'
##' @inheritParams estimate_infections
##' @inheritParams fill_missing
##' @importFrom data.table setDT frollsum
##' @return A data.table with the zero threshold applied.
apply_zero_threshold <- function(data, threshold = Inf,
obs_column = "confirm") {
assert_data_frame(data)
assert_numeric(threshold)
reported_cases <- data.table::setDT(data)

# Calculate `average_7_day` which for rows with `confirm == 0`
# (the only instance where this is being used) equates to the 7-day
# right-aligned moving average at the previous data point.
reported_cases <-
reported_cases[
,
`:=`(average_7_day = (
data.table::frollsum(get(obs_column), n = 8, na.rm = TRUE)
) / 7
)
]
# Check the case counts preceding each zero case count and set the zero to NA
# if the average over the previous 7 days is greater than the threshold
if (!is.infinite(threshold)) {
reported_cases <- reported_cases[
get(obs_column) == 0 & average_7_day > threshold,
paste(obs_column) := NA_integer_
]
}
reported_cases[is.na(get(obs_column)), paste(obs_column) := NA_integer_]
reported_cases[, "average_7_day" := NULL]
return(reported_cases[])
}
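
The thresholding rule can be sketched on a plain numeric vector. The following is a base R illustration, not the package implementation: `zero_to_na()` is a hypothetical name, and it uses an explicit loop where the diff uses `data.table::frollsum()`. As in the diff, a zero is only considered once seven earlier observations exist, and the averages are computed from the original counts.

```r
# Base R sketch of replacing suspicious zeros with NA
# (hypothetical helper, not part of EpiNow2).
zero_to_na <- function(x, threshold = Inf, window = 7L) {
  stopifnot(is.numeric(x), is.numeric(threshold), length(threshold) == 1)
  n <- length(x)
  out <- x
  # An infinite threshold disables the check; short series have no full window
  if (is.infinite(threshold) || n <= window) {
    return(out)
  }
  for (i in seq(window + 1L, n)) {
    if (!is.na(x[i]) && x[i] == 0) {
      # Average of the `window` observations before the zero, ignoring NAs
      prev_avg <- sum(x[(i - window):(i - 1L)], na.rm = TRUE) / window
      if (prev_avg > threshold) {
        out[i] <- NA_real_
      }
    }
  }
  out
}

# Made-up counts: zeros amid counts of around 10 per day look spurious
busy <- c(10, 12, 11, 9, 13, 10, 12, 0, 11, 0, 0)
cleaned <- zero_to_na(busy, threshold = 5)

# Made-up counts: zeros in a genuinely quiet series are left alone
quiet <- c(0, 1, 0, 0, 1, 0, 0, 0, 1)
unchanged <- zero_to_na(quiet, threshold = 5)
```

With these inputs the zeros at positions 8, 10 and 11 of `busy` become `NA`, while `quiet` is returned unchanged because its 7-day averages stay below the threshold.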
2 changes: 1 addition & 1 deletion R/utilities.R
@@ -442,6 +442,6 @@ globalVariables(
"..lowers", "..upper_CrI", "..uppers", "timing", "dataset", "last_confirm",
"report_date", "secondary", "id", "conv", "meanlog", "primary", "scaled",
"scaling", "sdlog", "lookup", "new_draw", ".draw", "p", "distribution",
"accumulate", "..present"
"accumulate", "..present", "reported_cases"
)
)
7 changes: 7 additions & 0 deletions _pkgdown.yml
@@ -78,6 +78,13 @@ reference:
contents:
- contains("_opts")
- opts_list
- title: Preprocess data
desc: Functions used for preprocessing data
contents:
- fill_missing
- add_breakpoints
- filter_leading_zeros
- apply_zero_threshold
- title: Summarise Across Regions
desc: Functions used for summarising across regions (designed for use with regional_epinow)
contents: