913: Functions for PIT histograms #949

Merged
merged 15 commits into from
Oct 22, 2024
20 changes: 7 additions & 13 deletions NAMESPACE
@@ -21,9 +21,9 @@ S3method(get_metrics,forecast_point)
S3method(get_metrics,forecast_quantile)
S3method(get_metrics,forecast_sample)
S3method(get_metrics,scores)
S3method(get_pit,default)
S3method(get_pit,forecast_quantile)
S3method(get_pit,forecast_sample)
S3method(get_pit_histogram,default)
S3method(get_pit_histogram,forecast_quantile)
S3method(get_pit_histogram,forecast_sample)
S3method(head,forecast)
S3method(print,forecast)
S3method(score,default)
@@ -56,7 +56,7 @@ export(get_forecast_counts)
export(get_forecast_unit)
export(get_metrics)
export(get_pairwise_comparisons)
export(get_pit)
export(get_pit_histogram)
export(interval_coverage)
export(is_forecast)
export(is_forecast_binary)
@@ -72,13 +72,12 @@ export(mad_sample)
export(new_forecast)
export(overprediction_quantile)
export(overprediction_sample)
export(pit_sample)
export(pit_histogram_sample)
export(plot_correlations)
export(plot_forecast_counts)
export(plot_heatmap)
export(plot_interval_coverage)
export(plot_pairwise_comparisons)
export(plot_pit)
export(plot_quantile_coverage)
export(plot_wis)
export(quantile_score)
@@ -103,6 +102,7 @@ importFrom(checkmate,assert_data_table)
importFrom(checkmate,assert_disjunct)
importFrom(checkmate,assert_factor)
importFrom(checkmate,assert_function)
importFrom(checkmate,assert_int)
importFrom(checkmate,assert_list)
importFrom(checkmate,assert_logical)
importFrom(checkmate,assert_matrix)
@@ -115,9 +115,7 @@ importFrom(checkmate,assert_vector)
importFrom(checkmate,check_atomic_vector)
importFrom(checkmate,check_function)
importFrom(checkmate,check_matrix)
importFrom(checkmate,check_number)
importFrom(checkmate,check_numeric)
importFrom(checkmate,check_set_equal)
importFrom(checkmate,check_vector)
importFrom(checkmate,test_atomic_vector)
importFrom(checkmate,test_list)
@@ -138,6 +136,7 @@ importFrom(data.table,as.data.table)
importFrom(data.table,copy)
importFrom(data.table,data.table)
importFrom(data.table,dcast)
importFrom(data.table,fcase)
importFrom(data.table,is.data.table)
importFrom(data.table,melt)
importFrom(data.table,nafill)
@@ -150,7 +149,6 @@ importFrom(data.table,setorderv)
importFrom(ggplot2,.data)
importFrom(ggplot2,`%+replace%`)
importFrom(ggplot2,aes)
importFrom(ggplot2,after_stat)
importFrom(ggplot2,coord_cartesian)
importFrom(ggplot2,coord_flip)
importFrom(ggplot2,element_blank)
@@ -159,7 +157,6 @@ importFrom(ggplot2,element_text)
importFrom(ggplot2,facet_grid)
importFrom(ggplot2,facet_wrap)
importFrom(ggplot2,geom_col)
importFrom(ggplot2,geom_histogram)
importFrom(ggplot2,geom_line)
importFrom(ggplot2,geom_linerange)
importFrom(ggplot2,geom_polygon)
@@ -175,7 +172,6 @@ importFrom(ggplot2,scale_fill_gradient)
importFrom(ggplot2,scale_fill_gradient2)
importFrom(ggplot2,scale_fill_manual)
importFrom(ggplot2,scale_y_continuous)
importFrom(ggplot2,stat)
importFrom(ggplot2,theme)
importFrom(ggplot2,theme_light)
importFrom(ggplot2,theme_minimal)
@@ -187,9 +183,7 @@ importFrom(purrr,partial)
importFrom(scoringRules,crps_sample)
importFrom(scoringRules,dss_sample)
importFrom(scoringRules,logs_sample)
importFrom(stats,as.formula)
importFrom(stats,cor)
importFrom(stats,density)
importFrom(stats,mad)
importFrom(stats,median)
importFrom(stats,na.omit)
3 changes: 2 additions & 1 deletion NEWS.md
@@ -20,6 +20,7 @@ of our [original](https://doi.org/10.48550/arXiv.2205.07090) `scoringutils` pape
- Users can now also use their own scoring rules (making use of the `metrics` argument, which takes in a named list of functions). Default scoring rules can be accessed using the function `get_metrics()`, which is a generic with S3 methods for each forecast type. It returns a named list of scoring rules suitable for the respective forecast object. For example, you could call `get_metrics(example_quantile)`. Column names of scores in the output of `score()` correspond to the names of the scoring rules (i.e. the names of the functions in the list of metrics).
- Instead of supplying arguments to `score()` to manipulate individual scoring rules users should now manipulate the metric list being supplied using `purrr::partial()` and `select_metric()`. See `?score()` for more information.
- the CRPS is now reported as decomposition into dispersion, overprediction and underprediction.
- functionality to calculate the Probability Integral Transform (PIT) has been deprecated and replaced by functionality to calculate PIT histograms, using the `get_pit_histogram()` function; as part of this change, nonrandomised PITs can now be calculated for count data, and this is done by default

### Creating a forecast object
- The `as_forecast_<type>()` functions create a forecast object and validate it. They also allow users to rename/specify required columns and specify the forecast unit in a single step, taking over the functionality of `set_forecast_unit()` in most cases. See `?as_forecast()` for more information.
@@ -73,7 +74,6 @@ of our [original](https://doi.org/10.48550/arXiv.2205.07090) `scoringutils` pape
- Renamed `interval_coverage_quantile()` to `interval_coverage()`.
- "range" was consistently renamed to "interval_range" in the code. The "range"-format (which was mostly used internally) was renamed to "interval"-format
- Renamed `correlation()` to `get_correlations()` and `plot_correlation()` to `plot_correlations()`
- `pit()` was renamed to `get_pit()` and converted to an S3 method.

### Deleted functions
- Removed abs_error and squared_error from the package in favour of `Metrics::ae` and `Metrics::se`.`get_duplicate_forecasts()` now sorts outputs according to the forecast unit, making it easier to spot duplicates. In addition, there is a `counts` option that allows the user to display the number of duplicates for each forecast unit, rather than the raw duplicated rows.
@@ -84,6 +84,7 @@ of our [original](https://doi.org/10.48550/arXiv.2205.07090) `scoringutils` pape
- Removed `interval_coverage_sample()` as users are now expected to convert to a quantile format first before scoring.
- Function `set_forecast_unit()` was deleted. Instead there is now a `forecast_unit` argument in `as_forecast_<type>()` as well as in `get_duplicate_forecasts()`.
- Removed `interval_coverage_dev_quantile()`. Users can still access the difference between nominal and actual interval coverage using `get_coverage()`.
- `pit()`, `pit_sample()` and `plot_pit()` have been removed and replaced by functionality to create PIT histograms (`pit_histogram_sampel()` and `get_pit_histogram()`)
Review comment (Contributor): typo: sampel -> sample

### Function changes
- `bias_quantile()` changed the way it handles forecasts where the median is missing: The median is now imputed by linear interpolation between the innermost quantiles. Previously, we imputed the median by simply taking the mean of the innermost quantiles.
42 changes: 36 additions & 6 deletions R/class-forecast-quantile.R
@@ -175,27 +175,57 @@ get_metrics.forecast_quantile <- function(x, select = NULL, exclude = NULL, ...)
}


#' @rdname get_pit
#' @rdname get_pit_histogram
#' @importFrom stats na.omit
#' @importFrom data.table `:=` as.data.table
#' @export
get_pit.forecast_quantile <- function(forecast, by, ...) {
get_pit_histogram.forecast_quantile <- function(forecast, num_bins = NULL,
breaks = NULL, by, ...) {
assert_number(num_bins, lower = 1, null.ok = TRUE)
assert_numeric(breaks, lower = 0, upper = 1, null.ok = TRUE)
forecast <- clean_forecast(forecast, copy = TRUE, na.omit = TRUE)
forecast <- as.data.table(forecast)
present_quantiles <- unique(c(0, forecast$quantile_level, 1))
present_quantiles <- round(present_quantiles, 10)

if (!is.null(breaks)) {
quantiles <- unique(c(0, breaks, 1))
} else if (is.null(num_bins) || num_bins == "auto") {
quantiles <- present_quantiles
} else {
quantiles <- seq(0, 1, 1 / num_bins)
}
## avoid rounding errors
quantiles <- round(quantiles, 10)
diffs <- round(diff(quantiles), 10)

if (length(setdiff(quantiles, present_quantiles)) > 0) {
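# pass the message as one character vector; a second positional string
# would not be treated as part of the message
cli::cli_warn(c(
"!" = "Some requested quantiles are missing in the forecast.",
"i" = "The PIT histogram will be based on the quantiles present in the forecast." # nolint
))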
}

forecast <- forecast[quantile_level %in% quantiles]
forecast[, quantile_coverage := (observed <= predicted)]

quantile_coverage <-
forecast[, .(quantile_coverage = mean(quantile_coverage)),
by = c(unique(c(by, "quantile_level")))]
quantile_coverage <- quantile_coverage[

bins <- sprintf("[%s,%s)", quantiles[-length(quantiles)], quantiles[-1])
mids <- (quantiles[-length(quantiles)] + quantiles[-1]) / 2

pit_histogram <- quantile_coverage[
order(quantile_level),
.(
quantile_level = c(quantile_level, 1),
pit_value = diff(c(0, quantile_coverage, 1))
density = diff(c(0, quantile_coverage, 1)) / diffs,
bin = bins,
mid = mids
),
by = c(get_forecast_unit(quantile_coverage))
]
return(quantile_coverage[])
return(pit_histogram[])
}
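
Note on the quantile method above: the idea can be sketched in a few lines of base R. The empirical coverage at each quantile level is differenced to give the probability mass per bin, and dividing by the bin width gives a density. The toy data and every variable name below (`predicted`, `observed`, `coverage`, `mass`, `pit_hist`) are assumptions made for illustration. This is not the package code, and it leaves out grouping by forecast unit and the handling of missing quantiles.

```r
# Sketch only: PIT histogram densities from empirical quantile coverage.
# Toy data and names are illustrative, not taken from the package.
quantile_levels <- c(0.25, 0.5, 0.75)

# One row per forecast, one column per quantile level above.
predicted <- matrix(
  c(1, 2, 3,
    2, 4, 6,
    0, 1, 2,
    3, 5, 7),
  ncol = 3, byrow = TRUE
)
observed <- c(2.5, 1, 1.5, 8)

# Share of observations at or below each predicted quantile.
coverage <- colMeans(observed <= predicted)

# Bin edges always include 0 and 1 as outer bounds.
edges <- c(0, quantile_levels, 1)
bin_width <- diff(edges)

# Mass per bin is the increase in coverage; dividing by width gives density.
mass <- diff(c(0, coverage, 1))
pit_hist <- data.frame(
  mid = (edges[-length(edges)] + edges[-1]) / 2,
  density = mass / bin_width
)
print(pit_hist)
```

A well-calibrated forecast would give densities close to 1 in every bin; with this toy input the second bin is empty and the third holds twice the expected mass.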


42 changes: 32 additions & 10 deletions R/class-forecast-sample.R
@@ -165,31 +165,53 @@ get_metrics.forecast_sample <- function(x, select = NULL, exclude = NULL, ...) {
}


#' @rdname get_pit
#' @importFrom stats na.omit
#' @rdname get_pit_histogram
#' @importFrom data.table `:=` as.data.table dcast
#' @inheritParams pit_sample
#' @importFrom checkmate assert_int assert_numeric
#' @inheritParams pit_histogram_sample
#' @seealso [pit_histogram_sample()]
#' @export
get_pit.forecast_sample <- function(forecast, by, n_replicates = 100, ...) {
get_pit_histogram.forecast_sample <- function(forecast, num_bins = 10,
breaks = NULL, by, integers = c(
"nonrandom", "random", "ignore"
), n_replicates = NULL, ...) {
integers <- match.arg(integers)
assert_int(num_bins, lower = 1, null.ok = FALSE)
assert_numeric(breaks, lower = 0, upper = 1, null.ok = TRUE)
forecast <- clean_forecast(forecast, copy = TRUE, na.omit = TRUE)
forecast <- as.data.table(forecast)

# if prediction type is not quantile, calculate PIT values based on samples
if (is.null(breaks)) {
quantiles <- seq(0, 1, 1 / num_bins)
} else {
quantiles <- unique(c(0, breaks, 1))
}

forecast_wide <- data.table::dcast(
forecast,
... ~ paste0("InternalSampl_", sample_id),
value.var = "predicted"
)

pit <- forecast_wide[, .(pit_value = pit_sample(
observed = observed,
predicted = as.matrix(.SD)
)),
bins <- sprintf("[%s,%s)", quantiles[-length(quantiles)], quantiles[-1])
mids <- (quantiles[-length(quantiles)] + quantiles[-1]) / 2

pit_histogram <- forecast_wide[, .(
density = pit_histogram_sample(
observed = observed,
predicted = as.matrix(.SD),
quantiles = quantiles,
integers = integers,
n_replicates = n_replicates
),
bin = bins,
mid = mids
),
by = by,
.SDcols = grepl("InternalSampl_", names(forecast_wide), fixed = TRUE)
]

return(pit[])
return(pit_histogram[])
}
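
Note on the sample method above: `pit_histogram_sample()` itself is not shown in this diff, so the following is only a rough base-R sketch of a sample-based PIT histogram for continuous forecasts, under the assumption that the PIT value of an observation is the share of predictive samples at or below it. It does not cover the `integers` options ("nonrandom", "random", "ignore") or `n_replicates`. The toy data and all variable names are hypothetical.

```r
# Sketch only: sample-based PIT histogram for continuous forecasts.
# Toy data and names are illustrative, not taken from the package.

# One row per observation, one column per predictive sample.
predicted <- matrix(
  c(1, 2, 3, 4, 5,
    2, 4, 6, 8, 10,
    0, 1, 2, 3, 4,
    5, 6, 7, 8, 9),
  ncol = 5, byrow = TRUE
)
observed <- c(2.5, 1, 3.5, 5.5)

# PIT value per observation: empirical CDF of the samples at the observation.
pit_values <- rowMeans(predicted <= observed)

# Equal-width bins on [0, 1], mirroring seq(0, 1, 1 / num_bins) in the diff.
num_bins <- 2
edges <- seq(0, 1, length.out = num_bins + 1)

# Count PIT values per bin (last bin closed on the right) and scale to a density.
bin_index <- findInterval(pit_values, edges, rightmost.closed = TRUE)
counts <- tabulate(bin_index, nbins = num_bins)
density <- (counts / length(pit_values)) / diff(edges)
print(density)
```

Three of the four toy PIT values fall below 0.5, so the first bin has density 1.5 and the second 0.5.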


2 changes: 1 addition & 1 deletion R/get-coverage.R
@@ -112,7 +112,7 @@ get_coverage <- function(forecast, by = "model") {
#' Default is "model".
#' @return ggplot object with a plot of interval coverage
#' @importFrom ggplot2 ggplot scale_colour_manual scale_fill_manual .data
#' facet_wrap facet_grid geom_polygon geom_line
#' facet_wrap facet_grid geom_polygon geom_line xlab ylab
#' @importFrom checkmate assert_subset
#' @importFrom data.table dcast
#' @export
65 changes: 65 additions & 0 deletions R/get-pit-histogram.R
@@ -0,0 +1,65 @@
#' @title Probability integral transformation histogram
#'
#' @description
#' Generate a Probability Integral Transformation (PIT) histogram for
#' validated forecast objects.
#'
#' See the examples for how to plot the result of this function.
#'
#' @inherit score params
#' @param num_bins The number of bins in the PIT histogram. For sample-based
#' forecasts, the default is 10 bins. For quantile-based forecasts, the
#' default is one bin for each available quantile.
#' You can control the number of bins by supplying a number. This is fine for
#' sample-based PIT histograms, but may fail for quantile-based formats. In
#' this case it is preferred to supply explicit break points using the
#' `breaks` argument.
#' @param breaks Numeric vector with the break points for the bins in the
#' PIT histogram. This is preferred when creating a PIT histogram based on
#' quantile-based data. Default is `NULL` and breaks will be determined by
#' `num_bins`. If `breaks` is used, `num_bins` will be ignored.
#' 0 and 1 will always be added as left and right bounds, respectively.
#' @param by Character vector with the columns according to which the
#' PIT values shall be grouped. If you e.g. have the columns 'model' and
#' 'location' in the input data and want to have a PIT histogram for
#' every model and location, specify `by = c("model", "location")`.
#' @inheritParams pit_histogram_sample
#' @return A data.table with density values for each bin in the PIT histogram.
#' @examples
#' library("ggplot2")
#'
#' example <- as_forecast_sample(example_sample_continuous)
Review comment (Contributor): isn't it already of class sample and so this line isn't needed?
#' result <- get_pit_histogram(example, by = "model")
#' ggplot(result, aes(x = mid, y = density)) +
#' geom_col() +
#' facet_wrap(. ~ model) +
#' labs(x = "Quantile", y = "Density")
#'
#' # example with quantile data
#' example <- as_forecast_quantile(example_quantile)
#' result <- get_pit_histogram(example, by = "model")
#' ggplot(result, aes(x = mid, y = density)) +
#' geom_col() +
#' facet_wrap(. ~ model) +
#' labs(x = "Quantile", y = "Density")
#' @export
#' @keywords scoring
#' @references
#' Sebastian Funk, Anton Camacho, Adam J. Kucharski, Rachel Lowe,
#' Rosalind M. Eggo, W. John Edmunds (2019) Assessing the performance of
#' real-time epidemic forecasts: A case study of Ebola in the Western Area
#' region of Sierra Leone, 2014-15, \doi{10.1371/journal.pcbi.1006785}
get_pit_histogram <- function(forecast, num_bins, breaks, by,
...) {
UseMethod("get_pit_histogram")
}


#' @rdname get_pit_histogram
#' @importFrom cli cli_abort
#' @export
get_pit_histogram.default <- function(forecast, num_bins, breaks, by, ...) {
cli_abort(c(
"!" = "The input needs to be a valid forecast object represented as quantiles or samples." # nolint
))
}