Skip to content

Commit

Permalink
359: convert get_pit() to S3 (#910)
Browse files Browse the repository at this point in the history
* convert get_pit() to S3

* edit news item
  • Loading branch information
sbfnk authored Sep 16, 2024
1 parent 656d817 commit 1f4f36e
Show file tree
Hide file tree
Showing 5 changed files with 63 additions and 24 deletions.
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ S3method(get_metrics,forecast_point)
S3method(get_metrics,forecast_quantile)
S3method(get_metrics,forecast_sample)
S3method(get_metrics,scores)
S3method(get_pit,default)
S3method(get_pit,forecast_quantile)
S3method(get_pit,forecast_sample)
S3method(head,forecast)
S3method(print,forecast)
S3method(score,default)
Expand Down
2 changes: 1 addition & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ of our [original](https://doi.org/10.48550/arXiv.2205.07090) `scoringutils` pape
- Renamed `interval_coverage_quantile()` and `interval_coverage_dev_quantile()` to `interval_coverage()` and `interval_coverage_deviation()`, respectively.
- "range" was consistently renamed to "interval_range" in the code. The "range"-format (which was mostly used internally) was renamed to "interval"-format
- Renamed `correlation()` to `get_correlations()` and `plot_correlation()` to `plot_correlations()`
- `pit()` was renamed to `get_pit()`.
- `pit()` was renamed to `get_pit()` and converted to an S3 method.

### Deleted functions
- Removed abs_error and squared_error from the package in favour of `Metrics::ae` and `Metrics::se`.`get_duplicate_forecasts()` now sorts outputs according to the forecast unit, making it easier to spot duplicates. In addition, there is a `counts` option that allows the user to display the number of duplicates for each forecast unit, rather than the raw duplicated rows.
Expand Down
64 changes: 43 additions & 21 deletions R/pit.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#' @title Probability integral transformation (sample-based version)
#' @title Probability integral transformation for counts
#'
#' @description Uses a Probability integral transformation (PIT) (or a
#' randomised PIT for integer forecasts) to
Expand Down Expand Up @@ -83,7 +83,6 @@
#' real-time epidemic forecasts: A case study of Ebola in the Western Area
#' region of Sierra Leone, 2014-15, \doi{10.1371/journal.pcbi.1006785}
#' @keywords metric

pit_sample <- function(observed,
predicted,
n_replicates = 100) {
Expand Down Expand Up @@ -143,29 +142,30 @@ pit_sample <- function(observed,
#' region of Sierra Leone, 2014-15, \doi{10.1371/journal.pcbi.1006785}
#' @keywords scoring

get_pit <- function(forecast,
by,
n_replicates = 100) {
#' @keywords scoring
#' @export
get_pit <- function(forecast, by, ...) {
UseMethod("get_pit")
}

#' @rdname get_pit
#' @importFrom cli cli_abort
#' @export
get_pit.default <- function(forecast, by, ...) {
cli_abort(c(
"!" = "The input needs to be a valid forecast object represented as quantiles or samples." # nolint
))
}

#' @rdname get_pit
#' @importFrom stats na.omit
#' @importFrom data.table `:=` as.data.table dcast
#' @inheritParams pit_sample n_replicates
#' @export
get_pit.forecast_sample <- function(forecast, by, n_replicates = 100, ...) {
forecast <- clean_forecast(forecast, copy = TRUE, na.omit = TRUE)
forecast_type <- get_forecast_type(forecast)
forecast <- as.data.table(forecast)

if (forecast_type == "quantile") {
forecast[, quantile_coverage := (observed <= predicted)]
quantile_coverage <-
forecast[, .(quantile_coverage = mean(quantile_coverage)),
by = c(unique(c(by, "quantile_level")))]
quantile_coverage <- quantile_coverage[order(quantile_level),
.(
quantile_level = c(quantile_level, 1),
pit_value = diff(c(0, quantile_coverage, 1))
),
by = c(get_forecast_unit(quantile_coverage))
]
return(quantile_coverage[])
}

# if prediction type is not quantile, calculate PIT values based on samples
forecast_wide <- data.table::dcast(forecast,
... ~ paste0("InternalSampl_", sample_id),
Expand All @@ -182,3 +182,25 @@ get_pit <- function(forecast,

return(pit[])
}

#' @rdname get_pit
#' @importFrom stats na.omit
#' @importFrom data.table `:=` as.data.table
#' @export
get_pit.forecast_quantile <- function(forecast, by, ...) {
forecast <- clean_forecast(forecast, copy = TRUE, na.omit = TRUE)
forecast <- as.data.table(forecast)

forecast[, quantile_coverage := (observed <= predicted)]
quantile_coverage <-
forecast[, .(quantile_coverage = mean(quantile_coverage)),
by = c(unique(c(by, "quantile_level")))]
quantile_coverage <- quantile_coverage[order(quantile_level),
.(
quantile_level = c(quantile_level, 1),
pit_value = diff(c(0, quantile_coverage, 1))
),
by = c(get_forecast_unit(quantile_coverage))
]
return(quantile_coverage[])
}
16 changes: 15 additions & 1 deletion man/get_pit.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/pit_sample.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 1f4f36e

Please sign in to comment.