Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

359: convert get_pit() to S3 #910

Merged
merged 2 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ S3method(get_metrics,forecast_point)
S3method(get_metrics,forecast_quantile)
S3method(get_metrics,forecast_sample)
S3method(get_metrics,scores)
S3method(get_pit,default)
S3method(get_pit,forecast_quantile)
S3method(get_pit,forecast_sample)
S3method(head,forecast)
S3method(print,forecast)
S3method(score,default)
Expand Down
2 changes: 1 addition & 1 deletion NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,7 @@ of our [original](https://doi.org/10.48550/arXiv.2205.07090) `scoringutils` pape
- Renamed `interval_coverage_quantile()` and `interval_coverage_dev_quantile()` to `interval_coverage()` and `interval_coverage_deviation()`, respectively.
- "range" was consistently renamed to "interval_range" in the code. The "range"-format (which was mostly used internally) was renamed to "interval"-format
- Renamed `correlation()` to `get_correlations()` and `plot_correlation()` to `plot_correlations()`
- `pit()` was renamed to `get_pit()`.
- `pit()` was renamed to `get_pit()` and converted to an S3 method.

### Deleted functions
- Removed abs_error and squared_error from the package in favour of `Metrics::ae` and `Metrics::se`.`get_duplicate_forecasts()` now sorts outputs according to the forecast unit, making it easier to spot duplicates. In addition, there is a `counts` option that allows the user to display the number of duplicates for each forecast unit, rather than the raw duplicated rows.
Expand Down
64 changes: 43 additions & 21 deletions R/pit.R
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#' @title Probability integral transformation (sample-based version)
#' @title Probability integral transformation for counts
#'
#' @description Uses a Probability integral transformation (PIT) (or a
#' randomised PIT for integer forecasts) to
Expand Down Expand Up @@ -83,7 +83,6 @@
#' real-time epidemic forecasts: A case study of Ebola in the Western Area
#' region of Sierra Leone, 2014-15, \doi{10.1371/journal.pcbi.1006785}
#' @keywords metric

pit_sample <- function(observed,
predicted,
n_replicates = 100) {
Expand Down Expand Up @@ -143,29 +142,30 @@
#' region of Sierra Leone, 2014-15, \doi{10.1371/journal.pcbi.1006785}
#' @keywords scoring

get_pit <- function(forecast,
by,
n_replicates = 100) {
#' @keywords scoring
#' @export
get_pit <- function(forecast, by, ...) {
UseMethod("get_pit")
}

#' @rdname get_pit
#' @importFrom cli cli_abort
#' @export
get_pit.default <- function(forecast, by, ...) {
cli_abort(c(
"!" = "The input needs to be a valid forecast object represented as quantiles or samples." # nolint
))
}

#' @rdname get_pit
#' @importFrom stats na.omit
#' @importFrom data.table `:=` as.data.table dcast
#' @inheritParams pit_sample n_replicates
#' @export
get_pit.forecast_sample <- function(forecast, by, n_replicates = 100, ...) {
forecast <- clean_forecast(forecast, copy = TRUE, na.omit = TRUE)
forecast_type <- get_forecast_type(forecast)
forecast <- as.data.table(forecast)

if (forecast_type == "quantile") {
forecast[, quantile_coverage := (observed <= predicted)]
quantile_coverage <-
forecast[, .(quantile_coverage = mean(quantile_coverage)),
by = c(unique(c(by, "quantile_level")))]
quantile_coverage <- quantile_coverage[order(quantile_level),
.(
quantile_level = c(quantile_level, 1),
pit_value = diff(c(0, quantile_coverage, 1))
),
by = c(get_forecast_unit(quantile_coverage))
]
return(quantile_coverage[])
}

# if prediction type is not quantile, calculate PIT values based on samples
forecast_wide <- data.table::dcast(forecast,
... ~ paste0("InternalSampl_", sample_id),
Expand All @@ -182,3 +182,25 @@

return(pit[])
}

#' @rdname get_pit
#' @importFrom stats na.omit
#' @importFrom data.table `:=` as.data.table
#' @export
get_pit.forecast_quantile <- function(forecast, by, ...) {
forecast <- clean_forecast(forecast, copy = TRUE, na.omit = TRUE)
forecast <- as.data.table(forecast)

forecast[, quantile_coverage := (observed <= predicted)]
quantile_coverage <-
forecast[, .(quantile_coverage = mean(quantile_coverage)),
by = c(unique(c(by, "quantile_level")))]

Check warning on line 197 in R/pit.R

View workflow job for this annotation

GitHub Actions / lint-changed-files

file=R/pit.R,line=197,col=6,[indentation_linter] Hanging indent should be 13 spaces but is 6 spaces.
quantile_coverage <- quantile_coverage[order(quantile_level),
.(
quantile_level = c(quantile_level, 1),
pit_value = diff(c(0, quantile_coverage, 1))
),
by = c(get_forecast_unit(quantile_coverage))
]
return(quantile_coverage[])
}
16 changes: 15 additions & 1 deletion man/get_pit.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/pit_sample.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading