#' Ensure that an object is a `data.table`
#' @description
#' This function ensures that an object is a `data.table`.
#' If the object is not a data table, it is converted to one. If the object
#' is a data table, a copy of the object is returned.
#' @param data An object to ensure is a data table.
#' @return A data.table/a copy of an existing data.table.
#' @keywords internal
#' @importFrom data.table copy is.data.table as.data.table
ensure_data.table <- function(data) {
  # Anything that is not yet a data.table gets converted.
  if (!is.data.table(data)) {
    return(as.data.table(data))
  }
  # An existing data.table is copied instead of being handed back as-is.
  return(copy(data))
}


# ==============================================================================
# functions below will be deleted in the future
-#' @importFrom checkmate assert_subset -#' @inherit document_check_functions params return -#' @keywords internal_input_check -check_number_per_forecast <- function(data, forecast_unit) { - # This function doesn't return a forecast object so it's fine to unclass it - # to avoid validation error while subsetting - data <- as.data.table(data) - data <- na.omit(data) - # check whether there are the same number of quantiles, samples -------------- - data[, scoringutils_InternalNumCheck := length(predicted), by = forecast_unit] - n <- unique(data$scoringutils_InternalNumCheck) - data[, scoringutils_InternalNumCheck := NULL] - if (length(n) > 1) { - msg <- paste0( - "Some forecasts have different numbers of rows ", - "(e.g. quantiles or samples). ", - "scoringutils found: ", toString(n), - ". This may be a problem (it can potentially distort scores, ", - "making it more difficult to compare them), ", - "so make sure this is intended." - ) - return(msg) - } - return(TRUE) -} - - -#' Check that there are no duplicate forecasts -#' -#' @description -#' Runs [get_duplicate_forecasts()] and returns a message if an issue is -#' encountered -#' @inheritParams get_duplicate_forecasts -#' @inherit document_check_functions return -#' @keywords internal_input_check -check_duplicates <- function(data) { - check_duplicates <- get_duplicate_forecasts(data) - - if (nrow(check_duplicates) > 0) { - msg <- paste0( - "There are instances with more than one forecast for the same target. ", - "This can't be right and needs to be resolved. Maybe you need to ", - "check the unit of a single forecast and add missing columns? 
Use ", - "the function get_duplicate_forecasts() to identify duplicate rows" - ) - return(msg) - } - return(TRUE) -} - #' Check column names are present in a data.frame #' @description diff --git a/R/check-inputs-scoring-functions.R b/R/check-inputs-scoring-functions.R deleted file mode 100644 index 17578d458..000000000 --- a/R/check-inputs-scoring-functions.R +++ /dev/null @@ -1,332 +0,0 @@ -#' @title Assert that inputs are correct for sample-based forecast -#' @description -#' Function assesses whether the inputs correspond to the requirements for -#' scoring sample-based forecasts. -#' @param predicted Input to be checked. Should be a numeric nxN matrix of -#' predictive samples, n (number of rows) being the number of data points and -#' N (number of columns) the number of samples per forecast. -#' If `observed` is just a single number, then predicted values can just be a -#' vector of size N. -#' @importFrom checkmate assert assert_numeric check_matrix assert_matrix -#' @inherit document_assert_functions params return -#' @keywords internal_input_check -assert_input_sample <- function(observed, predicted) { - assert_numeric(observed, min.len = 1) - n_obs <- length(observed) - - if (n_obs == 1) { - assert( - # allow one of two options - check_numeric_vector(predicted, min.len = 1), - check_matrix(predicted, mode = "numeric", nrows = n_obs) - ) - } else { - assert_matrix(predicted, mode = "numeric", nrows = n_obs) - } - return(invisible(NULL)) -} - -#' @title Check that inputs are correct for sample-based forecast -#' @inherit assert_input_sample params description -#' @inherit document_check_functions return -#' @keywords internal_input_check -check_input_sample <- function(observed, predicted) { - result <- check_try(assert_input_sample(observed, predicted)) - return(result) -} - - -#' @title Assert that inputs are correct for quantile-based forecast -#' @description -#' Function assesses whether the inputs correspond to the -#' requirements for scoring 
quantile-based forecasts. -#' @param predicted Input to be checked. Should be nxN matrix of predictive -#' quantiles, n (number of rows) being the number of data points and N -#' (number of columns) the number of quantiles per forecast. -#' If `observed` is just a single number, then predicted can just be a -#' vector of size N. -#' @param quantile_level Input to be checked. Should be a vector of size N that -#' denotes the quantile levels corresponding to the columns of the prediction -#' matrix. -#' @param unique_quantile_levels Whether the quantile levels are required to be -#' unique (`TRUE`, the default) or not (`FALSE`). -#' @importFrom checkmate assert assert_numeric check_matrix check_vector -#' @inherit document_assert_functions params return -#' @keywords internal_input_check -assert_input_quantile <- function(observed, predicted, quantile_level, - unique_quantile_levels = TRUE) { - assert_numeric(observed, min.len = 1) - n_obs <- length(observed) - - assert_numeric( - quantile_level, min.len = 1, lower = 0, upper = 1, - unique = unique_quantile_levels - ) - n_quantiles <- length(quantile_level) - if (n_obs == 1) { - assert( - # allow one of two options - check_numeric_vector(predicted, min.len = n_quantiles), - check_matrix(predicted, mode = "numeric", - nrows = n_obs, ncols = n_quantiles) - ) - assert(check_vector(quantile_level, len = length(predicted))) - } else { - assert( - check_matrix(predicted, mode = "numeric", - nrows = n_obs, ncols = n_quantiles) - ) - } - return(invisible(NULL)) -} - -#' @title Check that inputs are correct for quantile-based forecast -#' @inherit assert_input_quantile params description -#' @inherit check_input_sample return description -#' @keywords internal_input_check -check_input_quantile <- function(observed, predicted, quantile_level) { - result <- check_try( - assert_input_quantile(observed, predicted, quantile_level) - ) - return(result) -} - - -#' @title Assert that inputs are correct for interval-based forecast -#' 
@description -#' Function assesses whether the inputs correspond to the -#' requirements for scoring interval-based forecasts. -#' @param lower Input to be checked. Should be a numeric vector of size n that -#' holds the predicted value for the lower bounds of the prediction intervals. -#' @param upper Input to be checked. Should be a numeric vector of size n that -#' holds the predicted value for the upper bounds of the prediction intervals. -#' @param interval_range Input to be checked. Should be a vector of size n that -#' denotes the interval range in percent. E.g. a value of 50 denotes a -#' (25%, 75%) prediction interval. -#' @importFrom cli cli_warn cli_abort -#' @inherit document_assert_functions params return -#' @keywords internal_input_check -assert_input_interval <- function(observed, lower, upper, interval_range) { - - assert(check_numeric_vector(observed, min.len = 1)) - n <- length(observed) - assert(check_numeric_vector(lower, len = n)) - assert(check_numeric_vector(upper, len = n)) - assert( - check_numeric_vector(interval_range, len = 1, lower = 0, upper = 100), - check_numeric_vector(interval_range, len = n, lower = 0, upper = 100) - ) - - diff <- upper - lower - diff <- diff[!is.na(diff)] - if (any(diff < 0)) { - cli_abort( - c( - "!" = "All values in `upper` need to be greater than or equal to - the corresponding values in `lower`" - ) - ) - } - if (any(interval_range > 0 & interval_range < 1, na.rm = TRUE)) { - #nolint start: keyword_quote_linter - cli_warn( - c( - "!" = "Found interval ranges between 0 and 1. Are you sure that's - right? An interval range of 0.5 e.g. implies a (49.75%, 50.25%) - prediction interval.", - "i" = "If you want to score a (25%, 75%) prediction interval, set - `interval_range = 50`." 
- ), - .frequency = "once", - .frequency_id = "small_interval_range" - ) - #nolint end - } - return(invisible(NULL)) -} - - -#' @title Check that inputs are correct for interval-based forecast -#' @inherit assert_input_interval params description -#' @inherit check_input_sample return description -#' @keywords internal_input_check -check_input_interval <- function(observed, lower, upper, interval_range) { - result <- check_try( - assert_input_interval(observed, lower, upper, interval_range) - ) - return(result) -} - - -#' @title Assert that inputs are correct for binary forecast -#' @description -#' Function assesses whether the inputs correspond to the -#' requirements for scoring binary forecasts. -#' @param observed Input to be checked. Should be a factor of length n with -#' exactly two levels, holding the observed values. -#' The highest factor level is assumed to be the reference level. This means -#' that `predicted` represents the probability that the observed value is -#' equal to the highest factor level. -#' @param predicted Input to be checked. `predicted` should be a vector of -#' length n, holding probabilities. Alternatively, `predicted` can be a matrix -#' of size n x 1. Values represent the probability that -#' the corresponding value in `observed` will be equal to the highest -#' available factor level. 
-#' @importFrom checkmate assert assert_factor -#' @inherit document_assert_functions return -#' @keywords internal_input_check -assert_input_binary <- function(observed, predicted) { - assert_factor(observed, n.levels = 2, min.len = 1) - assert_numeric(predicted, lower = 0, upper = 1) - assert_dims_ok_point(observed, predicted) - return(invisible(NULL)) -} - - -#' @title Check that inputs are correct for binary forecast -#' @inherit assert_input_binary params description -#' @inherit document_check_functions return -#' @keywords internal_input_check -check_input_binary <- function(observed, predicted) { - result <- check_try(assert_input_binary(observed, predicted)) - return(result) -} - - -#' @title Assert that inputs are correct for nominal forecasts -#' @description Function assesses whether the inputs correspond to the -#' requirements for scoring nominal forecasts. -#' @param observed Input to be checked. Should be a factor of length n with -#' N levels holding the observed values. n is the number of observations and -#' N is the number of possible outcomes the observed values can assume. -#' output) -#' @param predicted Input to be checked. Should be nxN matrix of predictive -#' quantiles, n (number of rows) being the number of data points and N -#' (number of columns) the number of possible outcomes the observed values -#' can assume. -#' If `observed` is just a single number, then predicted can just be a -#' vector of size N. -#' @param predicted Input to be checked. `predicted` should be a vector of -#' length n, holding probabilities. Alternatively, `predicted` can be a matrix -#' of size n x 1. Values represent the probability that -#' the corresponding value in `observed` will be equal to the highest -#' available factor level. -#' @param predicted_label Factor of length N with N levels, where N is the -#' number of possible outcomes the observed values can assume. 
-#' @importFrom checkmate assert_factor assert_numeric assert_set_equal -#' @inherit document_assert_functions return -#' @keywords internal_input_check -assert_input_nominal <- function(observed, predicted, predicted_label) { - # observed - assert_factor(observed, min.len = 1, min.levels = 2) - levels <- levels(observed) - n <- length(observed) - N <- length(levels) - - # predicted label - assert_factor( - predicted_label, len = N, - any.missing = FALSE, empty.levels.ok = FALSE - ) - assert_set_equal(levels(observed), levels(predicted_label)) - - # predicted - assert_numeric(predicted, min.len = 1, lower = 0, upper = 1) - if (n == 1) { - assert( - # allow one of two options - check_vector(predicted, len = N), - check_matrix(predicted, nrows = n, ncols = N) - ) - summed_predictions <- .rowSums(predicted, m = 1, n = N, na.rm = TRUE) - } else { - assert_matrix(predicted, nrows = n) - summed_predictions <- round(rowSums(predicted, na.rm = TRUE), 10) # avoid numeric errors - } - if (!all(summed_predictions == 1)) { - #nolint start: keyword_quote_linter object_usage_linter - row_indices <- as.character(which(summed_predictions != 1)) - cli_abort( - c( - `!` = "Probabilities belonging to a single forecast must sum to one", - `i` = "Found issues in row{?s} {row_indices} of {.var predicted}" - ) - ) - #nolint end - } - return(invisible(NULL)) -} - - -#' @title Assert that inputs are correct for point forecast -#' @description -#' Function assesses whether the inputs correspond to the -#' requirements for scoring point forecasts. -#' @param predicted Input to be checked. Should be a numeric vector with the -#' predicted values of size n. 
-#' @inherit document_assert_functions params return -#' @keywords internal_input_check -assert_input_point <- function(observed, predicted) { - assert(check_numeric(observed)) - assert(check_numeric(predicted)) - assert(check_dims_ok_point(observed, predicted)) - return(invisible(NULL)) -} - -#' @title Check that inputs are correct for point forecast -#' @inherit assert_input_point params description -#' @inherit document_check_functions return -#' @keywords internal_input_check -check_input_point <- function(observed, predicted) { - result <- check_try(assert_input_point(observed, predicted)) - return(result) -} - - -#' @title Assert Inputs Have Matching Dimensions -#' @description -#' Function assesses whether input dimensions match. In the -#' following, n is the number of observations / forecasts. Scalar values may -#' be repeated to match the length of the other input. -#' Allowed options are therefore: -#' - `observed` is vector of length 1 or length n -#' - `predicted` is: -#' - a vector of of length 1 or length n -#' - a matrix with n rows and 1 column -#' @inherit assert_input_binary -#' @inherit document_assert_functions return -#' @importFrom checkmate assert_vector check_matrix check_vector assert -#' @importFrom cli cli_abort -#' @keywords internal_input_check -assert_dims_ok_point <- function(observed, predicted) { - assert_vector(observed, min.len = 1) - n_obs <- length(observed) - assert( - check_vector(predicted, min.len = 1, strict = TRUE), - check_matrix(predicted, ncols = 1, nrows = n_obs) - ) - n_pred <- length(as.vector(predicted)) - # check that both are either of length 1 or of equal length - if ((n_obs != 1) && (n_pred != 1) && (n_obs != n_pred)) { - #nolint start: keyword_quote_linter object_usage_linter - cli_abort( - c( - "!" = "`observed` and `predicted` must either be of length 1 or - of equal length.", - "i" = "Found {n_obs} and {n_pred}." 
# ---- R/class-forecast-binary.R (new file) ------------------------------------

#' @title Create a `forecast` object for binary forecasts
#' @description
#' Create a `forecast` object for binary forecasts. See more information on
#' forecast types and expected input formats by calling `?`[as_forecast()].
#' @export
#' @inheritParams as_forecast
#' @return The object produced by `new_forecast(data, "forecast_binary")`,
#'   returned only after `assert_forecast()` ran on it without error.
#' @family functions to create forecast objects
#' @importFrom cli cli_warn
#' @keywords as_forecast
as_forecast_binary <- function(data,
                               forecast_unit = NULL,
                               observed = NULL,
                               predicted = NULL) {
  # Shared preprocessing first, then tag the result as a binary forecast.
  data <- as_forecast_generic(data, forecast_unit, observed, predicted)
  data <- new_forecast(data, "forecast_binary")
  # Called for its side effect: errors if the object is not valid.
  assert_forecast(data)
  return(data)
}


#' @export
#' @rdname assert_forecast
#' @importFrom cli cli_abort
#' @keywords validate-forecast-object
assert_forecast.forecast_binary <- function(
  forecast, forecast_type = NULL, verbose = TRUE, ...
) {
  forecast <- assert_forecast_generic(forecast, verbose)
  assert_forecast_type(forecast, actual = "binary", desired = forecast_type)

  # Neither of these two columns may be present in a binary forecast.
  columns_correct <- test_columns_not_present(
    forecast, c("sample_id", "quantile_level")
  )
  #nolint start: keyword_quote_linter
  if (!columns_correct) {
    # The message names exactly the columns that were checked above.
    cli_abort(
      c(
        "!" = "Checking `forecast`: Input looks like a binary forecast, but an
         additional column called `sample_id` or `quantile_level` was found.",
        "i" = "Please remove the column."
      )
    )
  }
  # `check_input_binary()` yields TRUE on success; anything else is
  # interpolated into the error message as the description of the problem.
  input_check <- check_input_binary(forecast$observed, forecast$predicted)
  if (!isTRUE(input_check)) {
    cli_abort(
      c(
        "!" = "Checking `forecast`: Input looks like a binary forecast, but
         found the following issue: {input_check}"
      )
    )
  }
  #nolint end
  return(invisible(NULL))
}


#' @export
#' @rdname is_forecast
is_forecast_binary <- function(x) {
  # Both the specific and the general class have to be present.
  inherits(x, "forecast_binary") && inherits(x, "forecast")
}


#' @importFrom stats na.omit
#' @importFrom data.table setattr copy
#' @rdname score
#' @export
score.forecast_binary <- function(forecast, metrics = get_metrics(forecast), ...) {
  forecast <- clean_forecast(forecast, copy = TRUE, na.omit = TRUE)
  metrics <- validate_metrics(metrics)
  forecast <- as.data.table(forecast)

  scores <- apply_metrics(
    forecast, metrics,
    forecast$observed, forecast$predicted
  )
  # The raw inputs are removed (by reference) from the score table.
  scores[, `:=`(predicted = NULL, observed = NULL)]

  scores <- as_scores(scores, metrics = names(metrics))
  return(scores[])
}


#' Get default metrics for binary forecasts
#'
#' @description
#' For binary forecasts, the default scoring rules are:
#' - "brier_score" = [brier_score()]
#' - "log_score" = [logs_binary()]
#' @inheritSection illustration-input-metric-binary-point Input format
#' @param x A forecast object (a validated data.table with predicted and
#'   observed values, see [as_forecast()]).
#' @param select A character vector of scoring rules to select from the list. If
#'   `select` is `NULL` (the default), all possible scoring rules are returned.
#' @param exclude A character vector of scoring rules to exclude from the list.
#'   If `select` is not `NULL`, this argument is ignored.
#' @param ... unused
#' @return A list of scoring functions.
#' @export
#' @family `get_metrics` functions
#' @keywords handle-metrics
#' @examples
#' get_metrics(example_binary)
#' get_metrics(example_binary, select = "brier_score")
#' get_metrics(example_binary, exclude = "log_score")
get_metrics.forecast_binary <- function(x, select = NULL, exclude = NULL, ...) {
  all <- list(
    brier_score = brier_score,
    log_score = logs_binary
  )
  select_metrics(all, select, exclude)
}


#' Binary forecast example data
#'
#' A data set with binary predictions for COVID-19 cases and deaths constructed
#' from data submitted to the European Forecast Hub.
#'
#' Predictions in the data set were constructed based on the continuous example
#' data by looking at the number of samples below the mean prediction.
#' The outcome was constructed as whether or not the actually
#' observed value was below or above that mean prediction.
#' This should not be understood as sound statistical practice, but rather
#' as a practical way to create an example data set.
#'
#' The data was created using the script create-example-data.R in the inst/
#' folder (or the top level folder in a compiled package).
#'
#' @format An object of class `forecast_binary` (see [as_forecast()]) with the
#' following columns:
#' \describe{
#'   \item{location}{the country for which a prediction was made}
#'   \item{location_name}{name of the country for which a prediction was made}
#'   \item{target_end_date}{the date for which a prediction was made}
#'   \item{target_type}{the target to be predicted (cases or deaths)}
#'   \item{observed}{A factor with observed values}
#'   \item{forecast_date}{the date on which a prediction was made}
#'   \item{model}{name of the model that generated the forecasts}
#'   \item{horizon}{forecast horizon in weeks}
#'   \item{predicted}{predicted value}
#' }
# nolint start
#' @source \url{https://github.com/european-modelling-hubs/covid19-forecast-hub-europe/commit/a42867b1ea152c57e25b04f9faa26cfd4bfd8fa6/}
# nolint end
"example_binary"


# ---- R/class-forecast-nominal.R (new file) -----------------------------------

#' @title Create a `forecast` object for nominal forecasts
#' @description
#' Nominal forecasts are a form of categorical forecasts where the possible
#' outcomes that the observed values can assume are not ordered. In that sense,
#' Nominal forecasts represent a generalisation of binary forecasts.
#' @inheritParams as_forecast
#' @param predicted_label (optional) Name of the column in `data` that denotes
#'   the outcome to which a predicted probability corresponds to.
#'   This column will be renamed to "predicted_label". Only applicable to
#'   nominal forecasts.
#' @return The object produced by `new_forecast(data, "forecast_nominal")`,
#'   returned only after `assert_forecast()` ran on it without error. The
#'   object passed in as `data` is not modified.
#' @family functions to create forecast objects
#' @keywords as_forecast
#' @export
as_forecast_nominal <- function(data,
                                forecast_unit = NULL,
                                observed = NULL,
                                predicted = NULL,
                                predicted_label = NULL) {
  assert_character(predicted_label, len = 1, null.ok = TRUE)
  assert_subset(predicted_label, names(data), empty.ok = TRUE)
  if (!is.null(predicted_label)) {
    # `setnames()` renames by reference. Rename on a copy so that the object
    # the caller passed in keeps its original column names.
    data <- ensure_data.table(data)
    setnames(data, old = predicted_label, new = "predicted_label")
  }

  data <- as_forecast_generic(data, forecast_unit, observed, predicted)
  data <- new_forecast(data, "forecast_nominal")
  # Called for its side effect: errors if the object is not valid.
  assert_forecast(data)
  return(data)
}


#' @export
#' @keywords check-forecasts
#' @importFrom checkmate assert_names assert_set_equal test_set_equal
#' @importFrom cli cli_abort
assert_forecast.forecast_nominal <- function(
  forecast, forecast_type = NULL, verbose = TRUE, ...
) {
  forecast <- assert_forecast_generic(forecast, verbose)
  assert(check_columns_present(forecast, "predicted_label"))
  assert_names(
    colnames(forecast),
    disjunct.from = c("sample_id", "quantile_level")
  )
  assert_forecast_type(forecast, actual = "nominal", desired = forecast_type)

  # levels need to be the same
  outcomes <- levels(forecast$observed)
  assert_set_equal(levels(forecast$predicted_label), outcomes)

  # forecasts need to be complete: within every forecast unit the set of
  # predicted labels has to equal the set of possible outcomes
  forecast_unit <- get_forecast_unit(forecast)
  complete <- as.data.table(forecast)[, .(
    correct = test_set_equal(as.character(predicted_label), outcomes)
  ), by = forecast_unit]

  if (!all(complete$correct)) {
    # Report the first forecast unit that FAILED the completeness check
    # (rows with `correct == FALSE`), not one that passed it.
    first_issue <- complete[(!correct), ..forecast_unit][1]
    first_issue <- lapply(first_issue, FUN = as.character)
    #nolint start: keyword_quote_linter object_usage_linter duplicate_argument_linter
    issue_location <- paste(names(first_issue), "==", first_issue)
    cli_abort(
      c(`!` = "Found incomplete forecasts",
        `i` = "For a nominal forecast, all possible outcomes must be assigned
         a probability explicitly.",
        `i` = "Found first missing probabilities in the forecast identified by
         {.emph {issue_location}}")
    )
    #nolint end
  }
  # NOTE(review): unlike the binary and point methods, which return
  # `invisible(NULL)`, this method returns the forecast itself. Kept as-is
  # because callers may rely on it.
  return(forecast[])
}


#' @export
#' @rdname is_forecast
is_forecast_nominal <- function(x) {
  # Both the specific and the general class have to be present.
  inherits(x, "forecast_nominal") && inherits(x, "forecast")
}


#' @importFrom stats na.omit
#' @importFrom data.table setattr
#' @rdname score
#' @export
score.forecast_nominal <- function(forecast, metrics = get_metrics(forecast), ...) {
  forecast <- clean_forecast(forecast, copy = TRUE, na.omit = TRUE)
  forecast_unit <- get_forecast_unit(forecast)
  metrics <- validate_metrics(metrics)
  forecast <- as.data.table(forecast)

  # transpose the forecasts that belong to the same forecast unit
  # make sure the labels and predictions are ordered in the same way
  f_transposed <- forecast[, .(
    predicted = list(predicted[order(predicted_label)]),
    observed = unique(observed)
  ), by = forecast_unit]

  observed <- f_transposed$observed
  # one row per forecast unit, one column per label
  predicted <- do.call(rbind, f_transposed$predicted)
  # The original passed `na.last = TRUE` to `unique()`, where it was swallowed
  # by `...` and had no effect. It is dropped so the call says what it does:
  # `sort()` with its default `na.last = NA` removes missing labels.
  predicted_label <- sort(unique(forecast$predicted_label))
  f_transposed[, c("observed", "predicted") := NULL]

  scores <- apply_metrics(
    f_transposed, metrics,
    observed, predicted, predicted_label, ...
  )
  scores <- as_scores(scores, metrics = names(metrics))
  return(scores[])
}


#' Get default metrics for nominal forecasts
#' @inheritParams get_metrics.forecast_binary
#' @description
#' For nominal forecasts, the default scoring rule is:
#' - "log_score" = [logs_nominal()]
#' @export
#' @family `get_metrics` functions
#' @keywords handle-metrics
#' @examples
#' get_metrics(example_nominal)
get_metrics.forecast_nominal <- function(x, select = NULL, exclude = NULL, ...) {
  all <- list(
    log_score = logs_nominal
  )
  select_metrics(all, select, exclude)
}


#' Nominal example data
#'
#' A data set with predictions for COVID-19 cases and deaths submitted to the
#' European Forecast Hub.
#'
#' The data was created using the script create-example-data.R in the inst/
#' folder (or the top level folder in a compiled package).
#'
#' @format An object of class `forecast_nominal` (see [as_forecast()]) with the
#' following columns:
#' \describe{
#'   \item{location}{the country for which a prediction was made}
#'   \item{target_end_date}{the date for which a prediction was made}
#'   \item{target_type}{the target to be predicted (cases or deaths)}
#'   \item{observed}{Factor: observed values}
#'   \item{location_name}{name of the country for which a prediction was made}
#'   \item{forecast_date}{the date on which a prediction was made}
#'   \item{predicted_label}{outcome that a probability corresponds to}
#'   \item{predicted}{predicted value}
#'   \item{model}{name of the model that generated the forecasts}
#'   \item{horizon}{forecast horizon in weeks}
#' }
# nolint start
#' @source \url{https://github.com/european-modelling-hubs/covid19-forecast-hub-europe/commit/a42867b1ea152c57e25b04f9faa26cfd4bfd8fa6/}
# nolint end
"example_nominal"


# ---- R/class-forecast-point.R (new file) -------------------------------------

#' @title Create a `forecast` object for point forecasts
#' @description
#' Create a `forecast` object for point forecasts. See more information on
#' forecast types and expected input formats by calling `?`[as_forecast()].
#' @inherit as_forecast params
#' @param ... Unused
#' @family functions to create forecast objects
#' @export
#' @keywords as_forecast transform
as_forecast_point <- function(data, ...) {
  # S3 generic: the method is chosen based on the class of `data`.
  UseMethod("as_forecast_point")
}


#' @rdname as_forecast_point
#' @export
#' @importFrom cli cli_warn
as_forecast_point.default <- function(data,
                                      forecast_unit = NULL,
                                      observed = NULL,
                                      predicted = NULL,
                                      ...) {
  # Shared preprocessing first, then tag the result as a point forecast.
  forecast <- as_forecast_generic(data, forecast_unit, observed, predicted)
  forecast <- new_forecast(forecast, "forecast_point")
  # Called for its side effect: errors if the object is not valid.
  assert_forecast(forecast)
  return(forecast)
}


#' @export
#' @rdname assert_forecast
#' @importFrom cli cli_abort
#' @keywords validate-forecast-object
assert_forecast.forecast_point <- function(
  forecast, forecast_type = NULL, verbose = TRUE, ...
) {
  forecast <- assert_forecast_generic(forecast, verbose)
  assert_forecast_type(forecast, actual = "point", desired = forecast_type)

  # `check_input_point()` yields TRUE on success; anything else is
  # interpolated into the error message as the description of the problem.
  input_check <- check_input_point(forecast$observed, forecast$predicted)
  if (isTRUE(input_check)) {
    return(invisible(NULL))
  }
  #nolint start: keyword_quote_linter object_usage_linter
  cli_abort(
    c(
      "!" = "Checking `forecast`: Input looks like a point forecast, but found
       the following issue: {input_check}"
    )
  )
  #nolint end
}


#' @export
#' @rdname is_forecast
is_forecast_point <- function(x) {
  # Both the general and the specific class have to be present.
  inherits(x, "forecast") && inherits(x, "forecast_point")
}


#' @importFrom Metrics se ae ape
#' @importFrom stats na.omit
#' @importFrom data.table setattr copy
#' @rdname score
#' @export
score.forecast_point <- function(forecast,
                                 metrics = get_metrics(forecast),
                                 ...) {
  forecast <- clean_forecast(forecast, copy = TRUE, na.omit = TRUE)
  metrics <- validate_metrics(metrics)
  forecast <- as.data.table(forecast)

  result <- apply_metrics(
    forecast, metrics,
    forecast$observed, forecast$predicted
  )
  # The raw inputs are removed (by reference) from the score table.
  result[, c("predicted", "observed") := NULL]

  result <- as_scores(result, metrics = names(metrics))
  return(result[])
}


#' Get default metrics for point forecasts
#'
#' @description
#' For point forecasts, the default scoring rules are:
#' - "ae_point" = [ae()][Metrics::ae()]
#' - "se_point" = [se()][Metrics::se()]
#' - "ape" = [ape()][Metrics::ape()]
#'
#' A note of caution: Every scoring rule for a point forecast
#' is implicitly minimised by a specific aspect of the predictive distribution
#' (see Gneiting, 2011).
#'
#' The mean squared error, for example, is only a meaningful scoring rule if
#' the forecaster actually reported the mean of their predictive distribution
#' as a point forecast. If the forecaster reported the median, then the mean
#' absolute error would be the appropriate scoring rule. If the scoring rule
#' and the predictive task do not align, the results will be misleading.
#'
#' Failure to respect this correspondence can lead to grossly misleading
#' results! Consider the example in the section below.
#' @inheritSection illustration-input-metric-binary-point Input format
#' @inheritParams get_metrics.forecast_binary
#' @export
#' @family `get_metrics` functions
#' @keywords handle-metrics
#' @examples
#' get_metrics(example_point, select = "ape")
#'
#' library(magrittr)
#' set.seed(123)
#' n <- 500
#' observed <- rnorm(n, 5, 4)^2
#'
#' predicted_mu <- mean(observed)
#' predicted_not_mu <- predicted_mu - rnorm(n, 10, 2)
#'
#' df <- data.frame(
#'   model = rep(c("perfect", "bad"), each = n),
#'   predicted = c(rep(predicted_mu, n), predicted_not_mu),
#'   observed = rep(observed, 2),
#'   id = rep(1:n, 2)
#' ) %>%
#'   as_forecast_point()
#' score(df) %>%
#'   summarise_scores()
#' @references
#' Making and Evaluating Point Forecasts, Gneiting, Tilmann, 2011,
#' Journal of the American Statistical Association.
get_metrics.forecast_point <- function(x, select = NULL, exclude = NULL, ...) {
  # The full set of defaults is narrowed down by `select` / `exclude`.
  select_metrics(
    list(
      ae_point = Metrics::ae,
      se_point = Metrics::se,
      ape = Metrics::ape
    ),
    select,
    exclude
  )
}


#' Point forecast example data
#'
#' A data set with predictions for COVID-19 cases and deaths submitted to the
#' European Forecast Hub. This data set is like the quantile example data, only
#' that the median has been replaced by a point forecast.
#'
#' The data was created using the script create-example-data.R in the inst/
#' folder (or the top level folder in a compiled package).
+#'
+#' @format An object of class `forecast_point` (see [as_forecast()]) with the
+#' following columns:
+#' \describe{
+#'   \item{location}{the country for which a prediction was made}
+#'   \item{target_end_date}{the date for which a prediction was made}
+#'   \item{target_type}{the target to be predicted (cases or deaths)}
+#'   \item{observed}{observed values}
+#'   \item{location_name}{name of the country for which a prediction was made}
+#'   \item{forecast_date}{the date on which a prediction was made}
+#'   \item{predicted}{predicted value}
+#'   \item{model}{name of the model that generated the forecasts}
+#'   \item{horizon}{forecast horizon in weeks}
+#' }
+# nolint start
+#' @source \url{https://github.com/european-modelling-hubs/covid19-forecast-hub-europe/commit/a42867b1ea152c57e25b04f9faa26cfd4bfd8fa6/}
+# nolint end
+"example_point"
diff --git a/R/class-forecast-quantile.R b/R/class-forecast-quantile.R
new file mode 100644
index 000000000..152121ee1
--- /dev/null
+++ b/R/class-forecast-quantile.R
@@ -0,0 +1,227 @@
+#' @title Create a `forecast` object for quantile-based forecasts
+#' @description
+#' Create a `forecast` object for quantile-based forecasts. See more information
+#' on forecast types and expected input formats by calling `?`[as_forecast()].
+#' @param ... Unused
+#' @family functions to create forecast objects
+#' @inheritParams as_forecast
+#' @export
+#' @keywords as_forecast transform
+as_forecast_quantile <- function(data, ...) {
+  UseMethod("as_forecast_quantile")
+}
+
+
+#' @rdname as_forecast_quantile
+#' @param quantile_level (optional) Name of the column in `data` that contains
+#' the quantile level of the predicted values. This column will be renamed to
+#' "quantile_level". Only applicable to quantile-based forecasts.
+#' @export
+#' @importFrom cli cli_warn
+as_forecast_quantile.default <- function(data,
+                                         forecast_unit = NULL,
+                                         observed = NULL,
+                                         predicted = NULL,
+                                         quantile_level = NULL,
+                                         ...) {
+  assert_character(quantile_level, len = 1, null.ok = TRUE)
+  assert_subset(quantile_level, names(data), empty.ok = TRUE)
+  if (!is.null(quantile_level)) {
+    # setnames() renames by reference. Work on a copy so that the column in
+    # the object the caller passed in is left untouched.
+    data <- ensure_data.table(data)
+    setnames(data, old = quantile_level, new = "quantile_level")
+  }
+
+  data <- as_forecast_generic(data, forecast_unit, observed, predicted)
+  data <- new_forecast(data, "forecast_quantile")
+  assert_forecast(data)
+  return(data)
+}
+
+
+#' @export
+#' @rdname assert_forecast
+#' @keywords validate-forecast-object
+assert_forecast.forecast_quantile <- function(
+  forecast, forecast_type = NULL, verbose = TRUE, ...
+) {
+  forecast <- assert_forecast_generic(forecast, verbose)
+  assert_forecast_type(forecast, actual = "quantile", desired = forecast_type)
+  # quantile levels are probabilities and must lie in [0, 1]
+  assert_numeric(forecast$quantile_level, lower = 0, upper = 1)
+  return(invisible(NULL))
+}
+
+
+#' @export
+#' @rdname is_forecast
+is_forecast_quantile <- function(x) {
+  inherits(x, "forecast_quantile") && inherits(x, "forecast")
+}
+
+
+#' @rdname as_forecast_point
+#' @description
+#' When converting a `forecast_quantile` object into a `forecast_point` object,
+#' the 0.5 quantile is extracted and returned as the point forecast.
+#' @export
+#' @keywords as_forecast
+as_forecast_point.forecast_quantile <- function(data, ...) {
+  assert_forecast(data, verbose = FALSE)
+  # the median must be among the available quantile levels
+  assert_subset(0.5, unique(data$quantile_level))
+
+  # At the end of this function, the object will have been turned from a
+  # forecast_quantile into a forecast_point and we don't want to validate it
+  # as a forecast_point during the conversion process. The correct class is
+  # restored at the end.
+  data <- as.data.table(data)
+
+  forecast <- data[quantile_level == 0.5]
+  forecast[, "quantile_level" := NULL]
+
+  point_forecast <- new_forecast(forecast, "forecast_point")
+  return(point_forecast)
+}
+
+
+#' @importFrom stats na.omit
+#' @importFrom data.table `:=` as.data.table rbindlist %like% setattr copy
+#' @rdname score
+#' @export
+score.forecast_quantile <- function(forecast, metrics = get_metrics(forecast), ...) {
+  forecast <- clean_forecast(forecast, copy = TRUE, na.omit = TRUE)
+  forecast_unit <- get_forecast_unit(forecast)
+  metrics <- validate_metrics(metrics)
+  forecast <- as.data.table(forecast)
+
+  # transpose the forecasts that belong to the same forecast unit
+  # make sure the quantiles and predictions are ordered in the same way
+  f_transposed <- forecast[, .(
+    predicted = list(predicted[order(quantile_level)]),
+    observed = unique(observed),
+    quantile_level = list(sort(quantile_level, na.last = TRUE)),
+    scoringutils_quantile_level = toString(sort(quantile_level, na.last = TRUE))
+  ), by = forecast_unit]
+
+  # split according to quantile_level lengths and do calculations for different
+  # quantile_level lengths separately. The function `wis()` assumes that all
+  # forecasts have the same quantile_levels
+  f_split <- split(f_transposed, f_transposed$scoringutils_quantile_level)
+
+  split_result <- lapply(f_split, function(forecast) {
+    # create a matrix out of the list of predicted values and quantile_levels
+    observed <- forecast$observed
+    predicted <- do.call(rbind, forecast$predicted)
+    quantile_level <- unlist(unique(forecast$quantile_level))
+    forecast[, c(
+      "observed", "predicted", "quantile_level", "scoringutils_quantile_level"
+    ) := NULL]
+
+    forecast <- apply_metrics(
+      forecast, metrics,
+      observed, predicted, quantile_level
+    )
+    return(forecast)
+  })
+  # groups may have produced different score columns; fill the gaps with NA
+  scores <- rbindlist(split_result, fill = TRUE)
+
+  scores <- as_scores(scores, metrics = names(metrics))
+
+  return(scores[])
+}
+
+
+#' Get default metrics for quantile-based forecasts
+#'
+#' @description
+#' For quantile-based forecasts, the default scoring rules are:
+#' - "wis" = [wis()]
+#' - "overprediction" = [overprediction_quantile()]
+#' - "underprediction" = [underprediction_quantile()]
+#' - "dispersion" = [dispersion_quantile()]
+#' - "bias" = [bias_quantile()]
+#' - "interval_coverage_50" = [interval_coverage()]
+#' - "interval_coverage_90" = purrr::partial(
+#'      interval_coverage, interval_range = 90
+#'    )
+#' - "ae_median" = [ae_median_quantile()]
+#'
+#' Note: The `interval_coverage_90` scoring rule is created by modifying
+#' [interval_coverage()], making use of the function [purrr::partial()].
+#' This construct allows the function to deal with arbitrary arguments in `...`,
+#' while making sure that only those that [interval_coverage()] can
+#' accept get passed on to it. `interval_range = 90` is set in the function
+#' definition, as passing an argument `interval_range = 90` to [score()] would
+#' mean it would also get passed to `interval_coverage_50`.
+#' @inheritSection illustration-input-metric-quantile Input format
+#' @inheritParams get_metrics.forecast_binary
+#' @export
+#' @family `get_metrics` functions
+#' @keywords handle-metrics
+#' @importFrom purrr partial
+#' @examples
+#' get_metrics(example_quantile, select = "wis")
+get_metrics.forecast_quantile <- function(x, select = NULL, exclude = NULL, ...) {
+  # full list of metrics for quantile-based forecasts; `select` / `exclude`
+  # are handed to select_metrics() to narrow it down
+  all <- list(
+    wis = wis,
+    overprediction = overprediction_quantile,
+    underprediction = underprediction_quantile,
+    dispersion = dispersion_quantile,
+    bias = bias_quantile,
+    interval_coverage_50 = interval_coverage,
+    # same function with `interval_range` pre-set to 90
+    interval_coverage_90 = purrr::partial(
+      interval_coverage, interval_range = 90
+    ),
+    ae_median = ae_median_quantile
+  )
+  select_metrics(all, select, exclude)
+}
+
+
+#' @rdname get_pit
+#' @importFrom stats na.omit
+#' @importFrom data.table `:=` as.data.table
+#' @export
+get_pit.forecast_quantile <- function(forecast, by, ...) {
+  forecast <- clean_forecast(forecast, copy = TRUE, na.omit = TRUE)
+  forecast <- as.data.table(forecast)
+
+  # indicator: is the observed value at or below the predicted quantile?
+  forecast[, quantile_coverage := (observed <= predicted)]
+  # share of observations covered, per quantile level within each `by` group
+  # NOTE(review): the outer c() around the `by` expressions here and below
+  # looks deliberate (presumably so data.table treats them as character
+  # vectors evaluated in this function) - confirm before simplifying.
+  quantile_coverage <-
+    forecast[, .(quantile_coverage = mean(quantile_coverage)),
+             by = c(unique(c(by, "quantile_level")))]
+  # differences between consecutive coverage values, padded with 0 and 1,
+  # give one pit_value per bin; the additional last bin is labelled with
+  # quantile level 1
+  quantile_coverage <- quantile_coverage[
+    order(quantile_level),
+    .(
+      quantile_level = c(quantile_level, 1),
+      pit_value = diff(c(0, quantile_coverage, 1))
+    ),
+    by = c(get_forecast_unit(quantile_coverage))
+  ]
+  return(quantile_coverage[])
+}
+
+
+#' Quantile example data
+#'
+#' A data set with predictions for COVID-19 cases and deaths submitted to the
+#' European Forecast Hub.
+#'
+#' The data was created using the script create-example-data.R in the inst/
+#' folder (or the top level folder in a compiled package).
+#' +#' @format An object of class `forecast_quantile` (see [as_forecast()]) with the +#' following columns: +#' \describe{ +#' \item{location}{the country for which a prediction was made} +#' \item{target_end_date}{the date for which a prediction was made} +#' \item{target_type}{the target to be predicted (cases or deaths)} +#' \item{observed}{Numeric: observed values} +#' \item{location_name}{name of the country for which a prediction was made} +#' \item{forecast_date}{the date on which a prediction was made} +#' \item{quantile_level}{quantile level of the corresponding prediction} +#' \item{predicted}{predicted value} +#' \item{model}{name of the model that generated the forecasts} +#' \item{horizon}{forecast horizon in weeks} +#' } +# nolint start +#' @source \url{https://github.com/european-modelling-hubs/covid19-forecast-hub-europe/commit/a42867b1ea152c57e25b04f9faa26cfd4bfd8fa6/} +# nolint end +"example_quantile" diff --git a/R/class-forecast-sample.R b/R/class-forecast-sample.R new file mode 100644 index 000000000..bb4ea569f --- /dev/null +++ b/R/class-forecast-sample.R @@ -0,0 +1,249 @@ +#' @title Create a `forecast` object for sample-based forecasts +#' @param sample_id (optional) Name of the column in `data` that contains the +#' sample id. This column will be renamed to "sample_id". Only applicable to +#' sample-based forecasts. 
+#' @inheritParams as_forecast
+#' @export
+#' @family functions to create forecast objects
+#' @importFrom cli cli_warn
+#' @keywords as_forecast
+as_forecast_sample <- function(data,
+                               forecast_unit = NULL,
+                               observed = NULL,
+                               predicted = NULL,
+                               sample_id = NULL) {
+  assert_character(sample_id, len = 1, null.ok = TRUE)
+  assert_subset(sample_id, names(data), empty.ok = TRUE)
+  if (!is.null(sample_id)) {
+    # setnames() renames by reference. Work on a copy so that the column in
+    # the object the caller passed in is left untouched.
+    data <- ensure_data.table(data)
+    setnames(data, old = sample_id, new = "sample_id")
+  }
+
+  data <- as_forecast_generic(data, forecast_unit, observed, predicted)
+  data <- new_forecast(data, "forecast_sample")
+  assert_forecast(data)
+  return(data)
+}
+
+
+#' @export
+#' @rdname assert_forecast
+#' @keywords validate-forecast-object
+assert_forecast.forecast_sample <- function(
+  forecast, forecast_type = NULL, verbose = TRUE, ...
+) {
+  # generic checks first, then make sure the detected type is "sample"
+  forecast <- assert_forecast_generic(forecast, verbose)
+  assert_forecast_type(forecast, actual = "sample", desired = forecast_type)
+  return(invisible(NULL))
+}
+
+
+#' @export
+#' @rdname is_forecast
+is_forecast_sample <- function(x) {
+  inherits(x, "forecast_sample") && inherits(x, "forecast")
+}
+
+
+#' @rdname as_forecast_quantile
+#' @description
+#' When creating a `forecast_quantile` object from a `forecast_sample` object,
+#' the quantiles are estimated by computing empirical quantiles from the samples
+#' via [quantile()]. Note that empirical quantiles are a biased estimator for
+#' the true quantiles in particular in the tails of the distribution and
+#' when the number of available samples is low.
+#' @param probs A numeric vector of quantile levels for which
+#' quantiles will be computed. Corresponds to the `probs` argument in
+#' [quantile()].
+#' @param type Type argument passed down to the quantile function. For more
+#' information, see [quantile()].
+#' @importFrom stats quantile
+#' @importFrom methods hasArg
+#' @importFrom checkmate assert_numeric
+#' @export
+as_forecast_quantile.forecast_sample <- function(
+  data,
+  probs = c(0.05, 0.25, 0.5, 0.75, 0.95),
+  type = 7,
+  ...
+) {
+  forecast <- copy(data)
+  assert_forecast(forecast, verbose = FALSE)
+  assert_numeric(probs, min.len = 1)
+  # everything that is neither a predicted value nor a sample id identifies
+  # (or is constant within) a single forecast
+  reserved_columns <- c("predicted", "sample_id")
+  by <- setdiff(colnames(forecast), reserved_columns)
+
+  # `probs` is completed to a set of levels that is symmetric around 0.5.
+  # Quantiles are computed for exactly these levels, so that labels and
+  # values always line up. Computing them for `probs` only gave mismatched
+  # lengths (or a silently recycled, mislabelled value for a single level)
+  # whenever `probs` was not already symmetric.
+  quantile_level <- unique(
+    round(c(probs, 1 - probs), digits = 10)
+  )
+
+  forecast <-
+    forecast[, .(quantile_level = quantile_level,
+                 predicted = quantile(x = predicted, probs = ..quantile_level,
+                                      type = ..type, na.rm = TRUE)),
+             by = by]
+
+  quantile_forecast <- new_forecast(forecast, "forecast_quantile")
+  assert_forecast(quantile_forecast)
+
+  return(quantile_forecast)
+}
+
+
+#' @importFrom stats na.omit
+#' @importFrom data.table setattr copy
+#' @rdname score
+#' @export
+score.forecast_sample <- function(forecast, metrics = get_metrics(forecast), ...) {
+  forecast <- clean_forecast(forecast, copy = TRUE, na.omit = TRUE)
+  forecast_unit <- get_forecast_unit(forecast)
+  metrics <- validate_metrics(metrics)
+  forecast <- as.data.table(forecast)
+
+  # transpose the forecasts that belong to the same forecast unit
+  # `scoringutils_N` is the number of samples of that forecast. It has to be
+  # length(sample_id): length(list(sample_id)) is always 1 and would put all
+  # forecasts into a single group regardless of their number of samples.
+  f_transposed <- forecast[, .(predicted = list(predicted),
+                               observed = unique(observed),
+                               scoringutils_N = length(sample_id)),
+                           by = forecast_unit]
+
+  # split according to number of samples and do calculations for different
+  # sample lengths separately
+  f_split <- split(f_transposed, f_transposed$scoringutils_N)
+
+  split_result <- lapply(f_split, function(forecast) {
+    # create a matrix with one row per forecast and one column per sample;
+    # this only works because all forecasts in a group have the same length
+    observed <- forecast$observed
+    predicted <- do.call(rbind, forecast$predicted)
+    forecast[, c("observed", "predicted", "scoringutils_N") := NULL]
+
+    forecast <- apply_metrics(
+      forecast, metrics,
+      observed, predicted
+    )
+    return(forecast)
+  })
+  scores <- rbindlist(split_result, fill = TRUE)
+  scores <- as_scores(scores, metrics = names(metrics))
+  return(scores[])
+}
+
+
+#' Get default metrics for sample-based forecasts
+#'
+#' @description
+#' For sample-based forecasts, the default scoring rules are:
+#' - "crps" = [crps_sample()]
+#' - "overprediction" = [overprediction_sample()]
+#' - "underprediction" = [underprediction_sample()]
+#' - "dispersion" = [dispersion_sample()]
+#' - "log_score" = [logs_sample()]
+#' - "dss" = [dss_sample()]
+#' - "mad" = [mad_sample()]
+#' - "bias" = [bias_sample()]
+#' - "ae_median" = [ae_median_sample()]
+#' - "se_mean" = [se_mean_sample()]
+#' @inheritSection illustration-input-metric-sample Input format
+#' @inheritParams get_metrics.forecast_binary
+#' @export
+#' @family `get_metrics` functions
+#' @keywords handle-metrics
+#' @examples
+#' get_metrics(example_sample_continuous, exclude = "mad")
+get_metrics.forecast_sample <- function(x, select = NULL, exclude = NULL, ...) {
+  all <- list(
+    bias = bias_sample,
+    dss = dss_sample,
+    crps = crps_sample,
+    overprediction = overprediction_sample,
+    underprediction = underprediction_sample,
+    dispersion = dispersion_sample,
+    log_score = logs_sample,
+    mad = mad_sample,
+    ae_median = ae_median_sample,
+    se_mean = se_mean_sample
+  )
+  select_metrics(all, select, exclude)
+}
+
+
+#' @rdname get_pit
+#' @importFrom stats na.omit
+#' @importFrom data.table `:=` as.data.table dcast
+#' @inheritParams pit_sample
+#' @export
+get_pit.forecast_sample <- function(forecast, by, n_replicates = 100, ...) {
+  forecast <- clean_forecast(forecast, copy = TRUE, na.omit = TRUE)
+  forecast <- as.data.table(forecast)
+
+  # reshape to wide format: one "InternalSampl_<sample_id>" column per sample
+  forecast_wide <- data.table::dcast(
+    forecast,
+    ... ~ paste0("InternalSampl_", sample_id),
+    value.var = "predicted"
+  )
+
+  # `n_replicates` is forwarded explicitly; it used to be accepted by this
+  # method but was never handed on to pit_sample()
+  pit <- forecast_wide[, .(pit_value = pit_sample(
+    observed = observed,
+    predicted = as.matrix(.SD),
+    n_replicates = n_replicates
+  )),
+  by = by,
+  .SDcols = grepl("InternalSampl_", names(forecast_wide), fixed = TRUE)
+  ]
+
+  return(pit[])
+}
+
+
+#' Continuous forecast example data
+#'
+#' A data set with continuous predictions for COVID-19 cases and deaths
+#' constructed from data submitted to the European Forecast Hub.
+#'
+#' The data was created using the script create-example-data.R in the inst/
+#' folder (or the top level folder in a compiled package).
+#' +#' @format An object of class `forecast_sample` (see [as_forecast()]) with the +#' following columns: +#' \describe{ +#' \item{location}{the country for which a prediction was made} +#' \item{target_end_date}{the date for which a prediction was made} +#' \item{target_type}{the target to be predicted (cases or deaths)} +#' \item{observed}{observed values} +#' \item{location_name}{name of the country for which a prediction was made} +#' \item{forecast_date}{the date on which a prediction was made} +#' \item{model}{name of the model that generated the forecasts} +#' \item{horizon}{forecast horizon in weeks} +#' \item{predicted}{predicted value} +#' \item{sample_id}{id for the corresponding sample} +#' } +# nolint start +#' @source \url{https://github.com/european-modelling-hubs/covid19-forecast-hub-europe/commit/a42867b1ea152c57e25b04f9faa26cfd4bfd8fa6/} +# nolint end +"example_sample_continuous" + + +#' Discrete forecast example data +#' +#' A data set with integer predictions for COVID-19 cases and deaths +#' constructed from data submitted to the European Forecast Hub. +#' +#' The data was created using the script create-example-data.R in the inst/ +#' folder (or the top level folder in a compiled package). 
+#' +#' @format An object of class `forecast_sample` (see [as_forecast()]) with the +#' following columns: +#' \describe{ +#' \item{location}{the country for which a prediction was made} +#' \item{target_end_date}{the date for which a prediction was made} +#' \item{target_type}{the target to be predicted (cases or deaths)} +#' \item{observed}{observed values} +#' \item{location_name}{name of the country for which a prediction was made} +#' \item{forecast_date}{the date on which a prediction was made} +#' \item{model}{name of the model that generated the forecasts} +#' \item{horizon}{forecast horizon in weeks} +#' \item{predicted}{predicted value} +#' \item{sample_id}{id for the corresponding sample} +#' } +# nolint start +#' @source \url{https://github.com/european-modelling-hubs/covid19-forecast-hub-europe/commit/a42867b1ea152c57e25b04f9faa26cfd4bfd8fa6/} +# nolint end +"example_sample_discrete" diff --git a/R/forecast.R b/R/class-forecast.R similarity index 50% rename from R/forecast.R rename to R/class-forecast.R index 07838de61..b78010066 100644 --- a/R/forecast.R +++ b/R/class-forecast.R @@ -1,7 +1,7 @@ #' @title General information on creating a `forecast` object #' #' @description -#' There are several ``as_forecast_()` functions to process and validate +#' There are several `as_forecast_()` functions to process and validate #' a data.frame (or similar) or similar with forecasts and observations. If #' the input passes all input checks, those functions will be converted #' to a `forecast` object. A forecast object is a `data.table` with a @@ -94,221 +94,6 @@ as_forecast_generic <- function(data, } -#' @title Create a `forecast` object for binary forecasts -#' @description -#' Create a `forecast` object for binary forecasts. See more information on -#' forecast types and expected input formats by calling `?`[as_forecast()]. 
-#' @export -#' @inheritParams as_forecast -#' @family functions to create forecast objects -#' @importFrom cli cli_warn -#' @keywords as_forecast -as_forecast_binary <- function(data, - forecast_unit = NULL, - observed = NULL, - predicted = NULL) { - data <- as_forecast_generic(data, forecast_unit, observed, predicted) - data <- new_forecast(data, "forecast_binary") - assert_forecast(data) - return(data) -} - - -#' @title Create a `forecast` object for point forecasts -#' @description -#' Create a `forecast` object for point forecasts. See more information on -#' forecast types and expected input formats by calling `?`[as_forecast()]. -#' @inherit as_forecast params -#' @param ... Unused -#' @family functions to create forecast objects -#' @export -#' @keywords as_forecast transform -as_forecast_point <- function(data, ...) { - UseMethod("as_forecast_point") -} - - -#' @rdname as_forecast_point -#' @export -#' @importFrom cli cli_warn -as_forecast_point.default <- function(data, - forecast_unit = NULL, - observed = NULL, - predicted = NULL, - ...) { - data <- as_forecast_generic(data, forecast_unit, observed, predicted) - data <- new_forecast(data, "forecast_point") - assert_forecast(data) - return(data) -} - - -#' @rdname as_forecast_point -#' @description -#' When converting a `forecast_quantile` object into a `forecast_point` object, -#' the 0.5 quantile is extracted and returned as the point forecast. -#' @export -#' @keywords as_forecast -as_forecast_point.forecast_quantile <- function(data, ...) { - assert_forecast(data, verbose = FALSE) - assert_subset(0.5, unique(data$quantile_level)) - - # At end of this function, the object will have be turned from a - # forecast_quantile to a forecast_point and we don't want to validate it as a - # forecast_point during the conversion process. The correct class is restored - # at the end. 
- data <- as.data.table(data) - - forecast <- data[quantile_level == 0.5] - forecast[, "quantile_level" := NULL] - - point_forecast <- new_forecast(forecast, "forecast_point") - return(point_forecast) -} - - -#' @title Create a `forecast` object for quantile-based forecasts -#' @description -#' Create a `forecast` object for quantile-based forecasts. See more information -#' on forecast types and expected input formats by calling `?`[as_forecast()]. -#' @param ... Unused -#' @family functions to create forecast objects -#' @inheritParams as_forecast -#' @export -#' @keywords as_forecast transform -as_forecast_quantile <- function(data, ...) { - UseMethod("as_forecast_quantile") -} - - -#' @rdname as_forecast_quantile -#' @param quantile_level (optional) Name of the column in `data` that contains -#' the quantile level of the predicted values. This column will be renamed to -#' "quantile_level". Only applicable to quantile-based forecasts. -#' @export -#' @importFrom cli cli_warn -as_forecast_quantile.default <- function(data, - forecast_unit = NULL, - observed = NULL, - predicted = NULL, - quantile_level = NULL, - ...) { - assert_character(quantile_level, len = 1, null.ok = TRUE) - assert_subset(quantile_level, names(data), empty.ok = TRUE) - if (!is.null(quantile_level)) { - setnames(data, old = quantile_level, new = "quantile_level") - } - - data <- as_forecast_generic(data, forecast_unit, observed, predicted) - data <- new_forecast(data, "forecast_quantile") - assert_forecast(data) - return(data) -} - - -#' @rdname as_forecast_quantile -#' @description -#' When creating a `forecast_quantile` object from a `forecast_sample` object, -#' the quantiles are estimated by computing empircal quantiles from the samples -#' via [quantile()]. Note that empirical quantiles are a biased estimator for -#' the true quantiles in particular in the tails of the distribution and -#' when the number of available samples is low. 
-#' @param probs A numeric vector of quantile levels for which -#' quantiles will be computed. Corresponds to the `probs` argument in -#' [quantile()]. -#' @param type Type argument passed down to the quantile function. For more -#' information, see [quantile()]. -#' @importFrom stats quantile -#' @importFrom methods hasArg -#' @importFrom checkmate assert_numeric -#' @export -as_forecast_quantile.forecast_sample <- function( - data, - probs = c(0.05, 0.25, 0.5, 0.75, 0.95), - type = 7, - ... -) { - forecast <- copy(data) - assert_forecast(forecast, verbose = FALSE) - assert_numeric(probs, min.len = 1) - reserved_columns <- c("predicted", "sample_id") - by <- setdiff(colnames(forecast), reserved_columns) - - quantile_level <- unique( - round(c(probs, 1 - probs), digits = 10) - ) - - forecast <- - forecast[, .(quantile_level = quantile_level, - predicted = quantile(x = predicted, probs = ..probs, - type = ..type, na.rm = TRUE)), - by = by] - - quantile_forecast <- new_forecast(forecast, "forecast_quantile") - assert_forecast(quantile_forecast) - - return(quantile_forecast) -} - - -#' @title Create a `forecast` object for sample-based forecasts -#' @param sample_id (optional) Name of the column in `data` that contains the -#' sample id. This column will be renamed to "sample_id". Only applicable to -#' sample-based forecasts. 
-#' @inheritParams as_forecast -#' @export -#' @family functions to create forecast objects -#' @importFrom cli cli_warn -#' @keywords as_forecast -as_forecast_sample <- function(data, - forecast_unit = NULL, - observed = NULL, - predicted = NULL, - sample_id = NULL) { - assert_character(sample_id, len = 1, null.ok = TRUE) - assert_subset(sample_id, names(data), empty.ok = TRUE) - if (!is.null(sample_id)) { - setnames(data, old = sample_id, new = "sample_id") - } - - data <- as_forecast_generic(data, forecast_unit, observed, predicted) - data <- new_forecast(data, "forecast_sample") - assert_forecast(data) - return(data) -} - - -#' @title Create a `forecast` object for nominal forecasts -#' @description -#' Nominal forecasts are a form of categorical forecasts where the possible -#' outcomes that the observed values can assume are not ordered. In that sense, -#' Nominal forecasts represent a generalisation of binary forecasts. -#' @inheritParams as_forecast -#' @param predicted_label (optional) Name of the column in `data` that denotes -#' the outcome to which a predicted probability corresponds to. -#' This column will be renamed to "predicted_label". Only applicable to -#' nominal forecasts. 
-#' @family functions to create forecast objects -#' @keywords as_forecast -#' @export -as_forecast_nominal <- function(data, - forecast_unit = NULL, - observed = NULL, - predicted = NULL, - predicted_label = NULL) { - assert_character(predicted_label, len = 1, null.ok = TRUE) - assert_subset(predicted_label, names(data), empty.ok = TRUE) - if (!is.null(predicted_label)) { - setnames(data, old = predicted_label, new = "predicted_label") - } - - data <- as_forecast_generic(data, forecast_unit, observed, predicted) - data <- new_forecast(data, "forecast_nominal") - assert_forecast(data) - return(data) -} - - #' @title Assert that input is a forecast object and passes validations #' #' @description @@ -356,134 +141,6 @@ assert_forecast.default <- function( } -#' @export -#' @rdname assert_forecast -#' @importFrom cli cli_abort -#' @keywords validate-forecast-object -assert_forecast.forecast_binary <- function( - forecast, forecast_type = NULL, verbose = TRUE, ... -) { - forecast <- assert_forecast_generic(forecast, verbose) - assert_forecast_type(forecast, actual = "binary", desired = forecast_type) - - columns_correct <- test_columns_not_present( - forecast, c("sample_id", "quantile_level") - ) - if (!columns_correct) { - #nolint start: keyword_quote_linter - cli_abort( - c( - "!" = "Checking `forecast`: Input looks like a binary forecast, but an - additional column called `sample_id` or `quantile` was found.", - "i" = "Please remove the column." - ) - ) - } - input_check <- check_input_binary(forecast$observed, forecast$predicted) - if (!isTRUE(input_check)) { - cli_abort( - c( - "!" = "Checking `forecast`: Input looks like a binary forecast, but - found the following issue: {input_check}" - ) - ) - #nolint end - } - return(invisible(NULL)) -} - - -#' @export -#' @rdname assert_forecast -#' @importFrom cli cli_abort -#' @keywords validate-forecast-object -assert_forecast.forecast_point <- function( - forecast, forecast_type = NULL, verbose = TRUE, ... 
-) { - forecast <- assert_forecast_generic(forecast, verbose) - assert_forecast_type(forecast, actual = "point", desired = forecast_type) - #nolint start: keyword_quote_linter object_usage_linter - input_check <- check_input_point(forecast$observed, forecast$predicted) - if (!isTRUE(input_check)) { - cli_abort( - c( - "!" = "Checking `forecast`: Input looks like a point forecast, but found - the following issue: {input_check}" - ) - ) - #nolint end - } - return(invisible(NULL)) -} - - -#' @export -#' @rdname assert_forecast -#' @keywords validate-forecast-object -assert_forecast.forecast_quantile <- function( - forecast, forecast_type = NULL, verbose = TRUE, ... -) { - forecast <- assert_forecast_generic(forecast, verbose) - assert_forecast_type(forecast, actual = "quantile", desired = forecast_type) - assert_numeric(forecast$quantile_level, lower = 0, upper = 1) - return(invisible(NULL)) -} - - -#' @export -#' @rdname assert_forecast -#' @keywords validate-forecast-object -assert_forecast.forecast_sample <- function( - forecast, forecast_type = NULL, verbose = TRUE, ... -) { - forecast <- assert_forecast_generic(forecast, verbose) - assert_forecast_type(forecast, actual = "sample", desired = forecast_type) - return(invisible(NULL)) -} - - -#' @export -#' @keywords check-forecasts -#' @importFrom checkmate assert_names assert_set_equal test_set_equal -assert_forecast.forecast_nominal <- function( - forecast, forecast_type = NULL, verbose = TRUE, ... 
-) { - forecast <- assert_forecast_generic(forecast, verbose) - assert(check_columns_present(forecast, "predicted_label")) - assert_names( - colnames(forecast), - disjunct.from = c("sample_id", "quantile_level") - ) - assert_forecast_type(forecast, actual = "nominal", desired = forecast_type) - - # levels need to be the same - outcomes <- levels(forecast$observed) - assert_set_equal(levels(forecast$predicted_label), outcomes) - - # forecasts need to be complete - forecast_unit <- get_forecast_unit(forecast) - complete <- as.data.table(forecast)[, .( - correct = test_set_equal(as.character(predicted_label), outcomes) - ), by = forecast_unit] - - if (!all(complete$correct)) { - first_issue <- complete[(correct), ..forecast_unit][1] - first_issue <- lapply(first_issue, FUN = as.character) - #nolint start: keyword_quote_linter object_usage_linter duplicate_argument_linter - issue_location <- paste(names(first_issue), "==", first_issue) - cli_abort( - c(`!` = "Found incomplete forecasts", - `i` = "For a nominal forecast, all possible outcomes must be assigned - a probability explicitly.", - `i` = "Found first missing probabilities in the forecast identified by - {.emph {issue_location}}") - ) - #nolint end - } - return(forecast[]) -} - - #' @title Validation common to all forecast types #' #' @description @@ -553,6 +210,41 @@ assert_forecast_generic <- function(data, verbose = TRUE) { } +#' Check that all forecasts have the same number of rows +#' @description +#' Helper function that checks the number of rows (corresponding e.g to +#' quantiles or samples) per forecast. +#' If the number of quantiles or samples is the same for all forecasts, it +#' returns TRUE and a string with an error message otherwise. +#' @param forecast_unit Character vector denoting the unit of a single forecast. 
+#' @importFrom checkmate assert_subset +#' @inherit document_check_functions params return +#' @keywords internal_input_check +check_number_per_forecast <- function(data, forecast_unit) { + # This function doesn't return a forecast object so it's fine to unclass it + # to avoid validation error while subsetting + data <- as.data.table(data) + data <- na.omit(data) + # check whether there are the same number of quantiles, samples -------------- + data[, scoringutils_InternalNumCheck := length(predicted), by = forecast_unit] + n <- unique(data$scoringutils_InternalNumCheck) + data[, scoringutils_InternalNumCheck := NULL] + if (length(n) > 1) { + msg <- paste0( + "Some forecasts have different numbers of rows ", + "(e.g. quantiles or samples). ", + "scoringutils found: ", toString(n), + ". This may be a problem (it can potentially distort scores, ", + "making it more difficult to compare them), ", + "so make sure this is intended." + ) + return(msg) + } + return(TRUE) +} + + + #' Clean forecast object #' @description #' The function makes it possible to silently validate an object. In addition, @@ -623,35 +315,6 @@ is_forecast <- function(x) { inherits(x, "forecast") } -#' @export -#' @rdname is_forecast -is_forecast_sample <- function(x) { - inherits(x, "forecast_sample") && inherits(x, "forecast") -} - -#' @export -#' @rdname is_forecast -is_forecast_binary <- function(x) { - inherits(x, "forecast_binary") && inherits(x, "forecast") -} - -#' @export -#' @rdname is_forecast -is_forecast_point <- function(x) { - inherits(x, "forecast_point") && inherits(x, "forecast") -} - -#' @export -#' @rdname is_forecast -is_forecast_quantile <- function(x) { - inherits(x, "forecast_quantile") && inherits(x, "forecast") -} - -#' @export -#' @rdname is_forecast -is_forecast_nominal <- function(x) { - inherits(x, "forecast_nominal") && inherits(x, "forecast") -} #' @export `[.forecast` <- function(x, ...) 
{ @@ -687,6 +350,7 @@ is_forecast_nominal <- function(x) { } + #' @export `$<-.forecast` <- function(x, ..., value) { @@ -709,6 +373,7 @@ is_forecast_nominal <- function(x) { } + #' @export `[[<-.forecast` <- function(x, ..., value) { @@ -731,6 +396,7 @@ is_forecast_nominal <- function(x) { } + #' @export `[<-.forecast` <- function(x, ..., value) { @@ -753,6 +419,7 @@ is_forecast_nominal <- function(x) { } + #' @export #' @importFrom utils head head.forecast <- function(x, ...) { @@ -761,6 +428,7 @@ head.forecast <- function(x, ...) { head(as.data.table(x), ...) } + #' @export #' @importFrom utils tail tail.forecast <- function(x, ...) { @@ -768,3 +436,72 @@ tail.forecast <- function(x, ...) { # validation when we expect (and don't care) that objects are invalidated utils::tail(as.data.table(x), ...) } + + +#' @title Print information about a forecast object +#' @description +#' This function prints information about a forecast object, +#' including "Forecast type", "Score columns", +#' "Forecast unit". +#' +#' @param x A forecast object (a validated data.table with predicted and +#' observed values, see [as_forecast()]). +#' @param ... Additional arguments for [print()]. +#' @return Returns `x` invisibly. +#' @importFrom cli cli_inform cli_warn col_blue cli_text +#' @export +#' @keywords gain-insights +#' @examples +#' dat <- as_forecast_quantile(example_quantile) +#' print(dat) +print.forecast <- function(x, ...) { + + # get forecast type, forecast unit and score columns + forecast_type <- try( + do.call(get_forecast_type, list(forecast = x)), + silent = TRUE + ) + forecast_unit <- try( + do.call(get_forecast_unit, list(data = x)), + silent = TRUE + ) + + # Print forecast object information + if (inherits(forecast_type, "try-error")) { + cli_inform( + c( + "!" = "Could not determine forecast type due to error in validation." 
#nolint + ) + ) + } else { + cli_text( + col_blue( + "Forecast type: " + ), + "{forecast_type}" + ) + } + + if (inherits(forecast_unit, "try-error")) { + cli_inform( + c( + "!" = "Could not determine forecast unit." #nolint + ) + ) + } else { + cli_text( + col_blue( + "Forecast unit:" + ) + ) + cli_text( + "{forecast_unit}" + ) + } + + cat("\n") + + NextMethod() + + return(invisible(x)) +} diff --git a/R/class-scores.R b/R/class-scores.R new file mode 100644 index 000000000..6cc72be9b --- /dev/null +++ b/R/class-scores.R @@ -0,0 +1,142 @@ +#' Construct an object of class `scores` +#' @description +#' This function creates an object of class `scores` based on a +#' data.table or similar. +#' @param scores A data.table or similar with scores as produced by [score()]. +#' @param metrics A character vector with the names of the scores +#' (i.e. the names of the scoring rules used for scoring). +#' @param ... Additional arguments to [data.table::as.data.table()] +#' @keywords internal +#' @importFrom data.table as.data.table setattr +#' @return An object of class `scores` +#' @examples +#' \dontrun{ +#' df <- data.frame( +#' model = "A", +#' wis = "0.1" +#' ) +#' new_scores(df, "wis") +#' } +new_scores <- function(scores, metrics, ...) { + scores <- as.data.table(scores, ...) + class(scores) <- c("scores", class(scores)) + setattr(scores, "metrics", metrics) + return(scores[]) +} + + +#' Create an object of class `scores` from data +#' @description This convenience function wraps [new_scores()] and validates +#' the `scores` object. 
+#' @inherit new_scores params return +#' @importFrom checkmate assert_data_frame +#' @keywords internal +as_scores <- function(scores, metrics) { + assert_data_frame(scores) + present_metrics <- metrics[metrics %in% colnames(scores)] + scores <- new_scores(scores, present_metrics) + assert_scores(scores) + return(scores[]) +} + + +#' Validate an object of class `scores` +#' @description +#' This function validates an object of class `scores`, checking +#' that it has the correct class and that it has a `metrics` attribute. +#' @inheritParams new_scores +#' @returns Returns `NULL` invisibly +#' @importFrom checkmate assert_class assert_data_frame +#' @keywords internal +assert_scores <- function(scores) { + assert_data_frame(scores) + assert_class(scores, "scores") + # error if no metrics exists + + # throw warning if any of the metrics is not in the data + get_metrics.scores(scores, error = TRUE) + return(invisible(NULL)) +} + +#' @method `[` scores +#' @importFrom data.table setattr +#' @export +`[.scores` <- function(x, ...) { + ret <- NextMethod() + if (is.data.table(ret)) { + setattr(ret, "metrics", attr(x, "metrics")) + } else if (is.data.frame(ret)) { + attr(ret, "metrics") <- attr(x, "metrics") + } + return(ret) +} + + +#' @title Get names of the metrics that were used for scoring +#' @description +#' When applying a scoring rule via [score()], the names of the scoring rules +#' become column names of the +#' resulting data.table. In addition, an attribute `metrics` will be +#' added to the output, holding the names of the scores as a vector. +#' +#' This is done so that functions like [get_forecast_unit()] or +#' [summarise_scores()] can still identify which columns are part of the +#' forecast unit and which hold a score. +#' +#' `get_metrics()` accesses and returns the `metrics` attribute. If there is no +#' attribute, the function will return `NULL` (or, if `error = TRUE` will +#' produce an error instead). 
In addition, it checks the column names of the +#' input for consistency with the data stored in the `metrics` attribute. +#' +#' **Handling a missing or inconsistent `metrics` attribute**: +#' +#' If the metrics attribute is missing or is not consistent with the column +#' names of the data.table, you can either +#' +#' - run [score()] again, specifying names for the scoring rules manually, or +#' - add/update the attribute manually using +#' `attr(scores, "metrics") <- c("names", "of", "your", "scores")` (the +#' order does not matter). +#' +#' @param x A `scores` object, (a data.table with an attribute `metrics` as +#' produced by [score()]). +#' @param error Throw an error if there is no attribute called `metrics`? +#' Default is FALSE. +#' @param ... unused +#' @importFrom cli cli_abort cli_warn +#' @importFrom checkmate assert_data_frame +#' @return +#' Character vector with the names of the scoring rules that were used +#' for scoring. +#' @keywords handle-metrics +#' @family `get_metrics` functions +#' @export +get_metrics.scores <- function(x, error = FALSE, ...) { + assert_data_frame(x) + metrics <- attr(x, "metrics") + if (error && is.null(metrics)) { + #nolint start: keyword_quote_linter + cli_abort( + c( + "!" = "Input needs an attribute `metrics` with the names of the + scoring rules that were used for scoring.", + "i" = "See `?get_metrics` for further information." + ) + ) + #nolint end + } + + if (!all(metrics %in% names(x))) { + #nolint start: keyword_quote_linter object_usage_linter + missing <- setdiff(metrics, names(x)) + cli_warn( + c( + "!" = "The following scores have been previously computed, but are no + longer column names of the data: {.val {missing}}", + "i" = "See {.code ?get_metrics} for further information." 
+ ) + ) + #nolint end + } + + return(metrics) +} diff --git a/R/correlations.R b/R/correlations.R deleted file mode 100644 index a71eaa644..000000000 --- a/R/correlations.R +++ /dev/null @@ -1,47 +0,0 @@ -#' @title Calculate correlation between metrics -#' -#' @description -#' Calculate the correlation between different metrics for a data.frame of -#' scores as produced by [score()]. -#' -#' @param metrics A character vector with the metrics to show. If set to -#' `NULL` (default), all metrics present in `scores` will be shown. -#' @inheritParams get_pairwise_comparisons -#' @param ... Additional arguments to pass down to [cor()]. -#' @return -#' An object of class `scores` (a data.table with an additional -#' attribute `metrics` holding the names of the scores) with correlations -#' between different metrics -#' @importFrom data.table setDT -#' @importFrom stats cor na.omit -#' @importFrom cli cli_warn -#' @importFrom checkmate assert_subset -#' @export -#' @keywords scoring -#' @examples -#' library(magrittr) # pipe operator -#' -#' scores <- example_quantile %>% -#' as_forecast_quantile() %>% -#' score() -#' -#' get_correlations(scores) -get_correlations <- function(scores, - metrics = get_metrics.scores(scores), - ...) { - scores <- ensure_data.table(scores) - assert_subset(metrics, colnames(scores), empty.ok = FALSE) - df <- scores[, .SD, .SDcols = names(scores) %in% metrics] - - # define correlation matrix - cor_mat <- cor(as.matrix(df), ...) - - correlations <- new_scores( - as.data.frame((cor_mat)), - metrics = metrics, - keep.rownames = TRUE - ) - correlations <- copy(correlations)[, metric := rn][, rn := NULL] - - return(correlations[]) -} diff --git a/R/data.R b/R/data.R deleted file mode 100644 index 0a3bd72e8..000000000 --- a/R/data.R +++ /dev/null @@ -1,172 +0,0 @@ -#' Quantile example data -#' -#' A data set with predictions for COVID-19 cases and deaths submitted to the -#' European Forecast Hub. 
-#' -#' The data was created using the script create-example-data.R in the inst/ -#' folder (or the top level folder in a compiled package). -#' -#' @format An object of class `forecast_quantile` (see [as_forecast()]) with the -#' following columns: -#' \describe{ -#' \item{location}{the country for which a prediction was made} -#' \item{target_end_date}{the date for which a prediction was made} -#' \item{target_type}{the target to be predicted (cases or deaths)} -#' \item{observed}{Numeric: observed values} -#' \item{location_name}{name of the country for which a prediction was made} -#' \item{forecast_date}{the date on which a prediction was made} -#' \item{quantile_level}{quantile level of the corresponding prediction} -#' \item{predicted}{predicted value} -#' \item{model}{name of the model that generated the forecasts} -#' \item{horizon}{forecast horizon in weeks} -#' } -# nolint start -#' @source \url{https://github.com/european-modelling-hubs/covid19-forecast-hub-europe/commit/a42867b1ea152c57e25b04f9faa26cfd4bfd8fa6/} -# nolint end -"example_quantile" - - -#' Point forecast example data -#' -#' A data set with predictions for COVID-19 cases and deaths submitted to the -#' European Forecast Hub. This data set is like the quantile example data, only -#' that the median has been replaced by a point forecast. -#' -#' The data was created using the script create-example-data.R in the inst/ -#' folder (or the top level folder in a compiled package). 
-#' -#' @format An object of class `forecast_point` (see [as_forecast()]) with the -#' following columns: -#' \describe{ -#' \item{location}{the country for which a prediction was made} -#' \item{target_end_date}{the date for which a prediction was made} -#' \item{target_type}{the target to be predicted (cases or deaths)} -#' \item{observed}{observed values} -#' \item{location_name}{name of the country for which a prediction was made} -#' \item{forecast_date}{the date on which a prediction was made} -#' \item{predicted}{predicted value} -#' \item{model}{name of the model that generated the forecasts} -#' \item{horizon}{forecast horizon in weeks} -#' } -# nolint start -#' @source \url{https://github.com/european-modelling-hubs/covid19-forecast-hub-europe/commit/a42867b1ea152c57e25b04f9faa26cfd4bfd8fa6/} -# nolint end -"example_point" - - -#' Continuous forecast example data -#' -#' A data set with continuous predictions for COVID-19 cases and deaths -#' constructed from data submitted to the European Forecast Hub. -#' -#' The data was created using the script create-example-data.R in the inst/ -#' folder (or the top level folder in a compiled package). 
-#' -#' @format An object of class `forecast_sample` (see [as_forecast()]) with the -#' following columns: -#' \describe{ -#' \item{location}{the country for which a prediction was made} -#' \item{target_end_date}{the date for which a prediction was made} -#' \item{target_type}{the target to be predicted (cases or deaths)} -#' \item{observed}{observed values} -#' \item{location_name}{name of the country for which a prediction was made} -#' \item{forecast_date}{the date on which a prediction was made} -#' \item{model}{name of the model that generated the forecasts} -#' \item{horizon}{forecast horizon in weeks} -#' \item{predicted}{predicted value} -#' \item{sample_id}{id for the corresponding sample} -#' } -# nolint start -#' @source \url{https://github.com/european-modelling-hubs/covid19-forecast-hub-europe/commit/a42867b1ea152c57e25b04f9faa26cfd4bfd8fa6/} -# nolint end -"example_sample_continuous" - - -#' Discrete forecast example data -#' -#' A data set with integer predictions for COVID-19 cases and deaths -#' constructed from data submitted to the European Forecast Hub. -#' -#' The data was created using the script create-example-data.R in the inst/ -#' folder (or the top level folder in a compiled package). 
-#' -#' @format An object of class `forecast_sample` (see [as_forecast()]) with the -#' following columns: -#' \describe{ -#' \item{location}{the country for which a prediction was made} -#' \item{target_end_date}{the date for which a prediction was made} -#' \item{target_type}{the target to be predicted (cases or deaths)} -#' \item{observed}{observed values} -#' \item{location_name}{name of the country for which a prediction was made} -#' \item{forecast_date}{the date on which a prediction was made} -#' \item{model}{name of the model that generated the forecasts} -#' \item{horizon}{forecast horizon in weeks} -#' \item{predicted}{predicted value} -#' \item{sample_id}{id for the corresponding sample} -#' } -# nolint start -#' @source \url{https://github.com/european-modelling-hubs/covid19-forecast-hub-europe/commit/a42867b1ea152c57e25b04f9faa26cfd4bfd8fa6/} -# nolint end -"example_sample_discrete" - - -#' Binary forecast example data -#' -#' A data set with binary predictions for COVID-19 cases and deaths constructed -#' from data submitted to the European Forecast Hub. -#' -#' Predictions in the data set were constructed based on the continuous example -#' data by looking at the number of samples below the mean prediction. -#' The outcome was constructed as whether or not the actually -#' observed value was below or above that mean prediction. -#' This should not be understood as sound statistical practice, but rather -#' as a practical way to create an example data set. -#' -#' The data was created using the script create-example-data.R in the inst/ -#' folder (or the top level folder in a compiled package). 
-#' -#' @format An object of class `forecast_binary` (see [as_forecast()]) with the -#' following columns: -#' \describe{ -#' \item{location}{the country for which a prediction was made} -#' \item{location_name}{name of the country for which a prediction was made} -#' \item{target_end_date}{the date for which a prediction was made} -#' \item{target_type}{the target to be predicted (cases or deaths)} -#' \item{observed}{A factor with observed values} -#' \item{forecast_date}{the date on which a prediction was made} -#' \item{model}{name of the model that generated the forecasts} -#' \item{horizon}{forecast horizon in weeks} -#' \item{predicted}{predicted value} -#' } -# nolint start -#' @source \url{https://github.com/european-modelling-hubs/covid19-forecast-hub-europe/commit/a42867b1ea152c57e25b04f9faa26cfd4bfd8fa6/} -# nolint end -"example_binary" - - -#' Nominal example data -#' -#' A data set with predictions for COVID-19 cases and deaths submitted to the -#' European Forecast Hub. -#' -#' The data was created using the script create-example-data.R in the inst/ -#' folder (or the top level folder in a compiled package). 
-#' -#' @format An object of class `forecast_nominal` (see [as_forecast()]) with the -#' following columns: -#' \describe{ -#' \item{location}{the country for which a prediction was made} -#' \item{target_end_date}{the date for which a prediction was made} -#' \item{target_type}{the target to be predicted (cases or deaths)} -#' \item{observed}{Numeric: observed values} -#' \item{location_name}{name of the country for which a prediction was made} -#' \item{forecast_date}{the date on which a prediction was made} -#' \item{predicted_label}{outcome that a probabilty corresponds to} -#' \item{predicted}{predicted value} -#' \item{model}{name of the model that generated the forecasts} -#' \item{horizon}{forecast horizon in weeks} -#' } -# nolint start -#' @source \url{https://github.com/european-modelling-hubs/covid19-forecast-hub-europe/commit/a42867b1ea152c57e25b04f9faa26cfd4bfd8fa6/} -# nolint end -"example_nominal" diff --git a/R/default-scoring-rules.R b/R/default-scoring-rules.R deleted file mode 100644 index 0f9aae29f..000000000 --- a/R/default-scoring-rules.R +++ /dev/null @@ -1,260 +0,0 @@ -#' @title Select metrics from a list of functions -#' -#' @description -#' Helper function to return only the scoring rules selected by -#' the user from a list of possible functions. -#' -#' @param metrics A list of scoring functions. -#' @param select A character vector of scoring rules to select from the list. If -#' `select` is `NULL` (the default), all possible scoring rules are returned. -#' @param exclude A character vector of scoring rules to exclude from the list. -#' If `select` is not `NULL`, this argument is ignored. -#' @return A list of scoring functions. 
-#' @keywords handle-metrics -#' @importFrom checkmate assert_subset assert_list -#' @export -#' @examples -#' select_metrics( -#' metrics = get_metrics(example_binary), -#' select = "brier_score" -#' ) -#' select_metrics( -#' metrics = get_metrics(example_binary), -#' exclude = "log_score" -#' ) -select_metrics <- function(metrics, select = NULL, exclude = NULL) { - assert_character(x = c(select, exclude), null.ok = TRUE) - assert_list(metrics, names = "named") - allowed <- names(metrics) - - if (is.null(select) && is.null(exclude)) { - return(metrics) - } - if (is.null(select)) { - assert_subset(exclude, allowed) - select <- allowed[!allowed %in% exclude] - return(metrics[select]) - } - assert_subset(select, allowed) - return(metrics[select]) -} - -#' Get metrics -#' -#' @description -#' Generic function to to obtain default metrics availble for scoring or metrics -#' that were used for scoring. -#' -#' - If called on `forecast` object it returns a list of functions that can be -#' used for scoring. -#' - If called on a `scores` object (see [score()]), it returns a character vector -#' with the names of the metrics that were used for scoring. -#' -#' See the documentation for the actual methods in the `See Also` section below -#' for more details. Alternatively call `?get_metrics.` or -#' `?get_metrics.scores`. -#' -#' @param x A `forecast` or `scores` object. -#' @param ... Additional arguments passed to the method. -#' @details -#' See [as_forecast()] for more information on `forecast` objects and [score()] -#' for more information on `scores` objects. -#' -#' @family `get_metrics` functions -#' @keywords handle-metrics -#' @export -get_metrics <- function(x, ...) 
{ - UseMethod("get_metrics") -} - - -#' Get default metrics for binary forecasts -#' -#' @description -#' For binary forecasts, the default scoring rules are: -#' - "brier_score" = [brier_score()] -#' - "log_score" = [logs_binary()] -#' @inheritSection illustration-input-metric-binary-point Input format -#' @param x A forecast object (a validated data.table with predicted and -#' observed values, see [as_forecast()]). -#' @param select A character vector of scoring rules to select from the list. If -#' `select` is `NULL` (the default), all possible scoring rules are returned. -#' @param exclude A character vector of scoring rules to exclude from the list. -#' If `select` is not `NULL`, this argument is ignored. -#' @param ... unused -#' @return A list of scoring functions. -#' @export -#' @family `get_metrics` functions -#' @keywords handle-metrics -#' @examples -#' get_metrics(example_binary) -#' get_metrics(example_binary, select = "brier_score") -#' get_metrics(example_binary, exclude = "log_score") -get_metrics.forecast_binary <- function(x, select = NULL, exclude = NULL, ...) { - all <- list( - brier_score = brier_score, - log_score = logs_binary - ) - select_metrics(all, select, exclude) -} - - -#' Get default metrics for nominal forecasts -#' @inheritParams get_metrics.forecast_binary -#' @description -#' For nominal forecasts, the default scoring rule is: -#' - "log_score" = [logs_nominal()] -#' @export -#' @family `get_metrics` functions -#' @keywords handle-metrics -#' @examples -#' get_metrics(example_nominal) -get_metrics.forecast_nominal <- function(x, select = NULL, exclude = NULL, ...) 
{ - all <- list( - log_score = logs_nominal - ) - select_metrics(all, select, exclude) -} - - -#' Get default metrics for point forecasts -#' -#' @description -#' For point forecasts, the default scoring rules are: -#' - "ae_point" = [ae()][Metrics::ae()] -#' - "se_point" = [se()][Metrics::se()] -#' - "ape" = [ape()][Metrics::ape()] -#' -#' A note of caution: Every scoring rule for a point forecast -#' is implicitly minimised by a specific aspect of the predictive distribution -#' (see Gneiting, 2011). -#' -#' The mean squared error, for example, is only a meaningful scoring rule if -#' the forecaster actually reported the mean of their predictive distribution -#' as a point forecast. If the forecaster reported the median, then the mean -#' absolute error would be the appropriate scoring rule. If the scoring rule -#' and the predictive task do not align, the results will be misleading. -#' -#' Failure to respect this correspondence can lead to grossly misleading -#' results! Consider the example in the section below. -#' @inheritSection illustration-input-metric-binary-point Input format -#' @inheritParams get_metrics.forecast_binary -#' @export -#' @family `get_metrics` functions -#' @keywords handle-metrics -#' @examples -#' get_metrics(example_point, select = "ape") -#' -#' library(magrittr) -#' set.seed(123) -#' n <- 500 -#' observed <- rnorm(n, 5, 4)^2 -#' -#' predicted_mu <- mean(observed) -#' predicted_not_mu <- predicted_mu - rnorm(n, 10, 2) -#' -#' df <- data.frame( -#' model = rep(c("perfect", "bad"), each = n), -#' predicted = c(rep(predicted_mu, n), predicted_not_mu), -#' observed = rep(observed, 2), -#' id = rep(1:n, 2) -#' ) %>% -#' as_forecast_point() -#' score(df) %>% -#' summarise_scores() -#' @references -#' Making and Evaluating Point Forecasts, Gneiting, Tilmann, 2011, -#' Journal of the American Statistical Association. -get_metrics.forecast_point <- function(x, select = NULL, exclude = NULL, ...) 
{ - all <- list( - ae_point = Metrics::ae, - se_point = Metrics::se, - ape = Metrics::ape - ) - select_metrics(all, select, exclude) -} - - -#' Get default metrics for sample-based forecasts -#' -#' @description -#' For sample-based forecasts, the default scoring rules are: -#' - "crps" = [crps_sample()] -#' - "overprediction" = [overprediction_sample()] -#' - "underprediction" = [underprediction_sample()] -#' - "dispersion" = [dispersion_sample()] -#' - "log_score" = [logs_sample()] -#' - "dss" = [dss_sample()] -#' - "mad" = [mad_sample()] -#' - "bias" = [bias_sample()] -#' - "ae_median" = [ae_median_sample()] -#' - "se_mean" = [se_mean_sample()] -#' @inheritSection illustration-input-metric-sample Input format -#' @inheritParams get_metrics.forecast_binary -#' @export -#' @family `get_metrics` functions -#' @keywords handle-metrics -#' @examples -#' get_metrics(example_sample_continuous, exclude = "mad") -get_metrics.forecast_sample <- function(x, select = NULL, exclude = NULL, ...) { - all <- list( - bias = bias_sample, - dss = dss_sample, - crps = crps_sample, - overprediction = overprediction_sample, - underprediction = underprediction_sample, - dispersion = dispersion_sample, - log_score = logs_sample, - mad = mad_sample, - ae_median = ae_median_sample, - se_mean = se_mean_sample - ) - select_metrics(all, select, exclude) -} - - -#' Get default metrics for quantile-based forecasts -#' -#' @description -#' For quantile-based forecasts, the default scoring rules are: -#' - "wis" = [wis()] -#' - "overprediction" = [overprediction_quantile()] -#' - "underprediction" = [underprediction_quantile()] -#' - "dispersion" = [dispersion_quantile()] -#' - "bias" = [bias_quantile()] -#' - "interval_coverage_50" = [interval_coverage()] -#' - "interval_coverage_90" = purrr::partial( -#' interval_coverage, interval_range = 90 -#' ) -#' - "ae_median" = [ae_median_quantile()] -#' -#' Note: The `interval_coverage_90` scoring rule is created by modifying -#' 
[interval_coverage()], making use of the function [purrr::partial()]. -#' This construct allows the function to deal with arbitrary arguments in `...`, -#' while making sure that only those that [interval_coverage()] can -#' accept get passed on to it. `interval_range = 90` is set in the function -#' definition, as passing an argument `interval_range = 90` to [score()] would -#' mean it would also get passed to `interval_coverage_50`. -#' @inheritSection illustration-input-metric-quantile Input format -#' @inheritParams get_metrics.forecast_binary -#' @export -#' @family `get_metrics` functions -#' @keywords handle-metrics -#' @importFrom purrr partial -#' @examples -#' get_metrics(example_quantile, select = "wis") -get_metrics.forecast_quantile <- function(x, select = NULL, exclude = NULL, ...) { - all <- list( - wis = wis, - overprediction = overprediction_quantile, - underprediction = underprediction_quantile, - dispersion = dispersion_quantile, - bias = bias_quantile, - interval_coverage_50 = interval_coverage, - interval_coverage_90 = purrr::partial( - interval_coverage, interval_range = 90 - ), - ae_median = ae_median_quantile - ) - select_metrics(all, select, exclude) -} diff --git a/R/forecast-unit.R b/R/forecast-unit.R new file mode 100644 index 000000000..b62c07b7e --- /dev/null +++ b/R/forecast-unit.R @@ -0,0 +1,64 @@ +#' @title Set unit of a single forecast manually +#' +#' @description +#' Helper function to set the unit of a single forecast (i.e. the +#' combination of columns that uniquely define a single forecast) manually. +#' This simple function keeps the columns specified in `forecast_unit` (plus +#' additional protected columns, e.g. for observed values, predictions or +#' quantile levels) and removes duplicate rows. `set_forecast_unit()` will +#' mainly be called when constructing a `forecast` object (see [as_forecast()]) +#' via the `forecast_unit` argument there. 
+#' +#' If not done explicitly, `scoringutils` attempts to determine the unit +#' of a single forecast automatically by simply assuming that all column names +#' are relevant to determine the forecast unit. This may lead to unexpected +#' behaviour, so setting the forecast unit explicitly can help make the code +#' easier to debug and easier to read. +#' +#' @inheritParams as_forecast +#' @param forecast_unit Character vector with the names of the columns that +#' uniquely identify a single forecast. +#' @importFrom cli cli_warn +#' @return A data.table with only those columns kept that are relevant to +#' scoring or denote the unit of a single forecast as specified by the user. +#' @importFrom data.table ':=' is.data.table copy +#' @importFrom checkmate assert_character assert_subset +#' @keywords as_forecast +#' @examples +#' library(magrittr) # pipe operator +#' example_quantile %>% +#' scoringutils:::set_forecast_unit( +#' c("location", "target_end_date", "target_type", "horizon", "model") +#' ) +set_forecast_unit <- function(data, forecast_unit) { + data <- ensure_data.table(data) + assert_subset(forecast_unit, names(data), empty.ok = FALSE) + keep_cols <- c(get_protected_columns(data), forecast_unit) + out <- unique(data[, .SD, .SDcols = keep_cols]) + return(out) +} + + +#' @title Get unit of a single forecast +#' @description +#' Helper function to get the unit of a single forecast, i.e. +#' the names of the columns that uniquely identify a single forecast. +#' This just takes all columns that are available in the data and subtracts +#' the columns that are protected, i.e. those returned by +#' [get_protected_columns()] as well as the names of the metrics that were +#' specified during scoring, if any.
+#' @inheritParams as_forecast +#' @inheritSection forecast_types Forecast unit +#' @return +#' A character vector with the column names that define the unit of +#' a single forecast +#' @importFrom checkmate assert_data_frame +#' @export +#' @keywords diagnose-inputs +get_forecast_unit <- function(data) { + assert_data_frame(data) + protected_columns <- get_protected_columns(data) + protected_columns <- c(protected_columns, attr(data, "metrics")) + forecast_unit <- setdiff(colnames(data), unique(protected_columns)) + return(forecast_unit) +} diff --git a/R/get-correlations.R b/R/get-correlations.R new file mode 100644 index 000000000..67e8a6234 --- /dev/null +++ b/R/get-correlations.R @@ -0,0 +1,144 @@ +#' @title Calculate correlation between metrics +#' +#' @description +#' Calculate the correlation between different metrics for a data.frame of +#' scores as produced by [score()]. +#' +#' @param metrics A character vector with the metrics to show. If set to +#' `NULL` (default), all metrics present in `scores` will be shown. +#' @inheritParams get_pairwise_comparisons +#' @param ... Additional arguments to pass down to [cor()]. +#' @return +#' An object of class `scores` (a data.table with an additional +#' attribute `metrics` holding the names of the scores) with correlations +#' between different metrics +#' @importFrom data.table setDT +#' @importFrom stats cor na.omit +#' @importFrom cli cli_warn +#' @importFrom checkmate assert_subset +#' @export +#' @keywords scoring +#' @examples +#' library(magrittr) # pipe operator +#' +#' scores <- example_quantile %>% +#' as_forecast_quantile() %>% +#' score() +#' +#' get_correlations(scores) +get_correlations <- function(scores, + metrics = get_metrics.scores(scores), + ...) { + scores <- ensure_data.table(scores) + assert_subset(metrics, colnames(scores), empty.ok = FALSE) + df <- scores[, .SD, .SDcols = names(scores) %in% metrics] + + # define correlation matrix + cor_mat <- cor(as.matrix(df), ...) 
+ + correlations <- new_scores( + as.data.frame((cor_mat)), + metrics = metrics, + keep.rownames = TRUE + ) + correlations <- copy(correlations)[, metric := rn][, rn := NULL] + + return(correlations[]) +} + + +#' @title Plot correlation between metrics +#' +#' @description +#' Plots a heatmap of correlations between different metrics. +#' +#' @param correlations A data.table of correlations between scores as produced +#' by [get_correlations()]. +#' @param digits A number indicating how many decimal places the correlations +#' should be rounded to. By default (`digits = NULL`) no rounding takes place. +#' @return +#' A ggplot object showing a coloured matrix of correlations between metrics. +#' @importFrom ggplot2 ggplot geom_tile geom_text aes scale_fill_gradient2 +#' element_text labs coord_cartesian theme element_blank +#' @importFrom data.table setDT melt +#' @importFrom checkmate assert_data_frame +#' @export +#' @return A ggplot object with a visualisation of correlations between metrics +#' @examples +#' library(magrittr) # pipe operator +#' scores <- example_quantile %>% +#' as_forecast_quantile %>% +#' score() +#' correlations <- scores %>% +#' summarise_scores() %>% +#' get_correlations() +#' plot_correlations(correlations, digits = 2) + +plot_correlations <- function(correlations, digits = NULL) { + + assert_data_frame(correlations) + metrics <- get_metrics.scores(correlations, error = TRUE) + + lower_triangle <- get_lower_tri(correlations[, .SD, .SDcols = metrics]) + + if (!is.null(digits)) { + lower_triangle <- round(lower_triangle, digits) + } + + + # check correlations is actually a matrix of correlations + col_present <- check_columns_present(correlations, "metric") + if (any(lower_triangle > 1, na.rm = TRUE) || !isTRUE(col_present)) { + #nolint start: keyword_quote_linter + cli_abort( + c( + "Found correlations > 1 or missing `metric` column.", + "i" = "Did you forget to call {.fn scoringutils::get_correlations}?" 
+ ) + ) + #nolint end + } + + rownames(lower_triangle) <- colnames(lower_triangle) + + # get plot data.frame + plot_df <- data.table::as.data.table(lower_triangle)[, metric := metrics] + plot_df <- na.omit(data.table::melt(plot_df, id.vars = "metric")) + + # refactor levels according to the metrics + plot_df[, metric := factor(metric, levels = metrics)] + plot_df[, variable := factor(variable, rev(metrics))] + + plot <- ggplot(plot_df, aes( + x = variable, y = metric, + fill = value + )) + + geom_tile( + color = "white", + width = 0.97, height = 0.97 + ) + + geom_text(aes(y = metric, label = value)) + + scale_fill_gradient2( + low = "steelblue", mid = "white", + high = "salmon", + name = "Correlation", + breaks = c(-1, -0.5, 0, 0.5, 1) + ) + + theme_scoringutils() + + theme( + axis.text.x = element_text( + angle = 90, vjust = 1, + hjust = 1 + ) + ) + + labs(x = "", y = "") + + coord_cartesian(expand = FALSE) + return(plot) +} + + +# helper function to obtain lower triangle of matrix +get_lower_tri <- function(cormat) { + cormat[lower.tri(cormat)] <- NA + return(cormat) +} diff --git a/R/get-coverage.R b/R/get-coverage.R new file mode 100644 index 000000000..d16fc6239 --- /dev/null +++ b/R/get-coverage.R @@ -0,0 +1,239 @@ +#' @title Get quantile and interval coverage values for quantile-based forecasts +#' +#' @description +#' For a validated forecast object in a quantile-based format +#' (see [as_forecast()] for more information), this function computes: +#' - interval coverage of central prediction intervals +#' - quantile coverage for predictive quantiles +#' - the deviation between desired and actual coverage (both for interval and +#' quantile coverage) +#' +#' Coverage values are computed for a specific level of grouping, as specified +#' in the `by` argument. By default, coverage values are computed per model. 
+#'
+#' **Interval coverage**
+#'
+#' Interval coverage for a given interval range is defined as the proportion of
+#' observations that fall within the corresponding central prediction intervals.
+#' Central prediction intervals are symmetric around the median and formed
+#' by two quantiles that denote the lower and upper bound. For example, the 50%
+#' central prediction interval is the interval between the 0.25 and 0.75
+#' quantiles of the predictive distribution.
+#'
+#' **Quantile coverage**
+#'
+#' Quantile coverage for a given quantile level is defined as the proportion of
+#' observed values that are smaller than or equal to the corresponding
+#' predictive quantile. For example, the 0.5 quantile coverage is the
+#' proportion of observed values that are smaller than or equal to the 0.5
+#' quantile of the predictive distribution.
+#' As with interval coverage, for a single observation and the quantile of a
+#' single predictive distribution, the value will either be `TRUE` or `FALSE`.
+#'
+#' **Coverage deviation**
+#'
+#' The coverage deviation is the difference between the actual coverage and
+#' the desired (nominal) coverage (can be either interval or quantile
+#' coverage). For example, if the desired coverage is 90% and the actual
+#' coverage is 80%, the coverage deviation is -0.1.
+#' @inheritParams score
+#' @param by character vector that denotes the level of grouping for which the
+#' coverage values should be computed. By default (`"model"`), one coverage
+#' value per model will be returned.
+#' @return
+#' A data.table with the columns specified in `by` and additional columns
+#' for the coverage values described above: "interval_coverage",
+#' "interval_coverage_deviation", "quantile_coverage" and
+#' "quantile_coverage_deviation".
+#' @importFrom data.table setcolorder
+#' @importFrom checkmate assert_subset
+#' @examples
+#' library(magrittr) # pipe operator
+#' example_quantile %>%
+#'   as_forecast_quantile() %>%
+#'   get_coverage(by = "model")
+#' @keywords scoring
+#' @export
+get_coverage <- function(forecast, by = "model") {
+  # input checks ---------------------------------------------------------------
+  forecast <- clean_forecast(forecast, copy = TRUE, na.omit = TRUE)
+  assert_subset(get_forecast_type(forecast), "quantile")
+
+  # remove "quantile_level" and "interval_range" from `by` if present, as these
+  # are included anyway
+  by <- setdiff(by, c("quantile_level", "interval_range"))
+  assert_subset(by, names(forecast))
+
+  # convert to wide interval format and compute interval coverage --------------
+  interval_forecast <- quantile_to_interval(forecast, format = "wide")
+  # an observation is covered if it lies within [lower, upper] (inclusive)
+  interval_forecast[,
+    interval_coverage := (observed <= upper) & (observed >= lower)
+  ][, c("lower", "upper", "observed") := NULL]
+  # deviation = actual coverage - nominal coverage (interval_range is in %)
+  interval_forecast[, interval_coverage_deviation :=
+                      interval_coverage - interval_range / 100]
+
+  # merge interval range data with original data -------------------------------
+  # preparations
+  forecast[, interval_range := get_range_from_quantile(quantile_level)]
+  forecast_cols <- colnames(forecast) # store so we can reset column order later
+  forecast_unit <- get_forecast_unit(forecast)
+
+  forecast <- merge(forecast, interval_forecast,
+                    by = unique(c(forecast_unit, "interval_range")))
+
+  # compute quantile coverage and deviation ------------------------------------
+  forecast[, quantile_coverage := observed <= predicted]
+  forecast[, quantile_coverage_deviation := quantile_coverage - quantile_level]
+
+  # summarise coverage values according to `by` and cleanup --------------------
+  # reset column order
+  new_metrics <- c("interval_coverage", "interval_coverage_deviation",
+                   "quantile_coverage", "quantile_coverage_deviation")
+  setcolorder(forecast, unique(c(forecast_cols, "interval_range", new_metrics)))
+  # remove forecast class and convert to regular data.table
+  forecast <- as.data.table(forecast)
+  by <- unique(c(by, "quantile_level", "interval_range"))
+  # summarise: average the per-row indicators / deviations within each group
+  forecast <- forecast[, lapply(.SD, mean), by = by, .SDcols = new_metrics]
+  return(forecast[])
+}
+
+
+#' @title Plot interval coverage
+#'
+#' @description
+#' Plot interval coverage values (see [get_coverage()] for more information).
+#'
+#' @param coverage A data frame of coverage values as produced by
+#' [get_coverage()].
+#' @param colour According to which variable shall the graphs be coloured?
+#' Default is "model".
+#' @return ggplot object with a plot of interval coverage
+#' @importFrom ggplot2 ggplot scale_colour_manual scale_fill_manual .data
+#' facet_wrap facet_grid geom_polygon geom_line
+#' @importFrom checkmate assert_subset
+#' @importFrom data.table dcast
+#' @export
+#' @examples
+#' \dontshow{
+#'   data.table::setDTthreads(2) # restricts number of cores used on CRAN
+#' }
+#' example <- as_forecast_quantile(example_quantile)
+#' coverage <- get_coverage(example, by = "model")
+#' plot_interval_coverage(coverage)
+plot_interval_coverage <- function(coverage,
+                                   colour = "model") {
+  coverage <- ensure_data.table(coverage)
+  assert_subset(colour, names(coverage))
+
+  # in case quantile columns are present, remove them and then take unique
+  # values. This doesn't visually affect the plot, but prevents lines from being
+  # drawn twice. Only columns that actually exist are deleted, so no warning
+  # has to be suppressed.
+  del <- intersect(
+    c("quantile_level", "quantile_coverage", "quantile_coverage_deviation"),
+    names(coverage)
+  )
+  if (length(del) > 0) {
+    coverage[, (del) := NULL]
+  }
+  coverage <- unique(coverage)
+
+  ## overall model calibration - empirical interval coverage
+  p1 <- ggplot(coverage, aes(
+    x = interval_range,
+    colour = .data[[colour]]
+  )) +
+    # shaded triangle above the diagonal
+    geom_polygon(
+      data = data.frame(
+        x = c(0, 0, 100),
+        y = c(0, 100, 100),
+        g = c("o", "o", "o"),
+        stringsAsFactors = TRUE
+      ),
+      aes(
+        x = x, y = y, group = g,
+        fill = g
+      ),
+      alpha = 0.05,
+      colour = "white",
+      fill = "olivedrab3"
+    ) +
+    # dashed diagonal: empirical coverage equal to nominal coverage
+    geom_line(
+      aes(y = interval_range),
+      colour = "grey",
+      linetype = "dashed"
+    ) +
+    geom_line(aes(y = interval_coverage * 100)) +
+    theme_scoringutils() +
+    ylab("% Obs inside interval") +
+    xlab("Nominal interval coverage") +
+    coord_cartesian(expand = FALSE)
+
+  return(p1)
+}
+
+
+#' @title Plot quantile coverage
+#'
+#' @description
+#' Plot quantile coverage values (see [get_coverage()] for more information).
+#'
+#' @inheritParams plot_interval_coverage
+#' @param colour String, according to which variable shall the graphs be
+#' coloured? Default is "model".
+#' @return A ggplot object with a plot of quantile coverage
+#' @importFrom ggplot2 ggplot scale_colour_manual scale_fill_manual .data aes
+#' scale_y_continuous geom_line
+#' @importFrom checkmate assert_subset assert_data_frame
+#' @importFrom data.table dcast
+#' @export
+#' @examples
+#' example <- as_forecast_quantile(example_quantile)
+#' coverage <- get_coverage(example, by = "model")
+#' plot_quantile_coverage(coverage)
+
+plot_quantile_coverage <- function(coverage,
+                                   colour = "model") {
+  coverage <- assert_data_frame(coverage)
+  assert_subset(colour, names(coverage))
+
+  # vertices of the shaded background polygon; `g` puts all six vertices into
+  # one group (spelled out with `rep()` instead of relying on recycling)
+  polygon_df <- data.frame(
+    x = c(
+      0, 0.5, 0.5,
+      0.5, 0.5, 1
+    ),
+    y = c(
+      0, 0, 0.5,
+      0.5, 1, 1
+    ),
+    g = rep("o", 6),
+    stringsAsFactors = TRUE
+  )
+
+  p2 <- ggplot(
+    data = coverage,
+    aes(x = quantile_level, colour = .data[[colour]])
+  ) +
+    geom_polygon(
+      data = polygon_df,
+      aes(
+        x = x, y = y, group = g,
+        fill = g
+      ),
+      alpha = 0.05,
+      colour = "white",
+      fill = "olivedrab3"
+    ) +
+    # dashed diagonal: empirical coverage equal to the quantile level
+    geom_line(
+      aes(y = quantile_level),
+      colour = "grey",
+      linetype = "dashed"
+    ) +
+    geom_line(aes(y = quantile_coverage)) +
+    theme_scoringutils() +
+    xlab("Quantile level") +
+    ylab("% Obs below quantile level") +
+    # coverage is stored as a proportion, the axis is labelled in percent
+    scale_y_continuous(
+      labels = function(x) {
+        paste(100 * x)
+      }
+    ) +
+    coord_cartesian(expand = FALSE)
+
+  return(p2)
+}
diff --git a/R/get-duplicate-forecasts.R b/R/get-duplicate-forecasts.R
new file mode 100644
index 000000000..e3c426464
--- /dev/null
+++ b/R/get-duplicate-forecasts.R
@@ -0,0 +1,72 @@
+#' @title Find duplicate forecasts
+#'
+#' @description
+#' Internal helper function to identify duplicate forecasts, i.e.
+#' instances where there is more than one forecast for the same prediction
+#' target.
+#'
+#' @inheritParams as_forecast
+#' @param counts Should the output show the number of duplicates per forecast
+#' unit instead of the individual duplicated rows? Default is `FALSE`.
+#' @return A data.frame with all rows for which a duplicate forecast was found
+#' @export
+#' @importFrom checkmate assert_data_frame assert_subset
+#' @importFrom data.table setorderv
+#' @keywords diagnose-inputs
+#' @examples
+#' example <- rbind(example_quantile, example_quantile[1000:1010])
+#' get_duplicate_forecasts(example)
+get_duplicate_forecasts <- function(
+  data,
+  forecast_unit = NULL,
+  counts = FALSE
+) {
+  assert_data_frame(data)
+  data <- ensure_data.table(data)
+
+  if (!is.null(forecast_unit)) {
+    data <- set_forecast_unit(data, forecast_unit)
+  }
+  forecast_unit <- get_forecast_unit(data)
+
+  # columns that tell apart the individual rows of a single forecast
+  type_cols <- c("sample_id", "quantile_level", "predicted_label")
+  type <- type_cols[type_cols %in% colnames(data)]
+
+  # work on a plain data.table from here on
+  data <- as.data.table(data)
+
+  # rows sharing forecast unit and type column(s) with another row are
+  # duplicates
+  data[, scoringutils_InternalDuplicateCheck := .N, by = c(forecast_unit, type)]
+  out <- data[scoringutils_InternalDuplicateCheck > 1]
+
+  # same columns as `type`, but in the order in which they appear in the data
+  data_cols <- colnames(data)
+  col <- data_cols[data_cols %in% type_cols]
+  setorderv(out, cols = c(forecast_unit, col, "predicted"))
+  out[, scoringutils_InternalDuplicateCheck := NULL]
+
+  if (counts) {
+    out <- out[, .(n_duplicates = .N), by = c(get_forecast_unit(out))]
+  }
+
+  return(out[])
+}
+
+
+#' Check that there are no duplicate forecasts
+#'
+#' @description
+#' Runs [get_duplicate_forecasts()] and returns a message if an issue is
+#' encountered
+#' @inheritParams get_duplicate_forecasts
+#' @inherit document_check_functions return
+#' @keywords internal_input_check
+check_duplicates <- function(data) {
+  duplicates <- get_duplicate_forecasts(data)
+
+  # no duplicated rows found: the check passes
+  if (nrow(duplicates) == 0) {
+    return(TRUE)
+  }
+
+  msg <- paste0(
+    "There are instances with more than one forecast for the same target. ",
+    "This can't be right and needs to be resolved. Maybe you need to ",
+    "check the unit of a single forecast and add missing columns? Use ",
+    "the function get_duplicate_forecasts() to identify duplicate rows"
+  )
+  return(msg)
+}
diff --git a/R/get-forecast-counts.R b/R/get-forecast-counts.R
new file mode 100644
index 000000000..6582743e1
--- /dev/null
+++ b/R/get-forecast-counts.R
@@ -0,0 +1,151 @@
+#' @title Count number of available forecasts
+#'
+#' @description
+#' Given a data set with forecasts, this function counts the number of
+#' available forecasts.
+#' The level of grouping can be specified using the `by` argument (e.g. to
+#' count the number of forecasts per model, or the number of forecasts per
+#' model and location).
+#' This is useful to determine whether there are any missing forecasts.
+#'
+#' @param by character vector that denotes the categories over which the
+#' number of forecasts should be counted.
+#' By default this will be the unit of a single forecast as returned by
+#' [get_forecast_unit()] (i.e. all available columns apart from "protected"
+#' columns such as 'predicted', 'observed', 'quantile_level' or 'sample_id').
+#'
+#' @param collapse character vector (default: `c("quantile_level", "sample_id")`)
+#' with names of categories for which the number of rows should be collapsed
+#' to one when counting. For example, a single forecast is usually represented
+#' by a set of several quantiles or samples and collapsing these to one makes
+#' sure that a single forecast only gets counted once. Setting
+#' `collapse = c()` would mean that all quantiles / samples would be counted
+#' as individual forecasts.
+#'
+#' @return A data.table with columns as specified in `by` and an additional
+#' column "count" with the number of forecasts.
+#'
+#' @inheritParams score
+#' @importFrom data.table .I .N nafill
+#' @export
+#' @keywords gain-insights
+#' @examples
+#' \dontshow{
+#'   data.table::setDTthreads(2) # restricts number of cores used on CRAN
+#' }
+#'
+#' library(magrittr) # pipe operator
+#' example_quantile %>%
+#'   as_forecast_quantile() %>%
+#'   get_forecast_counts(by = c("model", "target_type"))
+get_forecast_counts <- function(forecast,
+                                by = get_forecast_unit(forecast),
+                                collapse = c("quantile_level", "sample_id")) {
+  forecast <- clean_forecast(forecast, copy = TRUE, na.omit = TRUE)
+  forecast_unit <- get_forecast_unit(forecast)
+  assert_subset(by, names(forecast), empty.ok = FALSE)
+  forecast <- as.data.table(forecast)
+
+  # collapse several rows to 1, e.g. treat a set of 10 quantiles as one,
+  # because they all belong to one single forecast that should be counted once.
+  # Rows are considered distinct only with respect to the columns that are not
+  # listed in `collapse` and that actually exist in the data.
+  candidate_cols <- c(forecast_unit, "quantile_level", "sample_id")
+  collapse_by <- intersect(
+    setdiff(candidate_cols, collapse),
+    names(forecast)
+  )
+
+  # keep the first row of every distinct combination
+  first_rows <- forecast[, .I[1], by = collapse_by]$V1
+  forecast <- forecast[first_rows]
+
+  # count number of rows = number of forecasts
+  out <- forecast[, .(count = .N), by = by]
+
+  # make sure that all combinations in "by" are included in the output (with
+  # count = 0). To achieve that, take the unique values of every grouping
+  # column and expand them to a full grid
+  grid_values <- lapply(unclass(out), unique)
+  grid_values$count <- NULL
+  out_empty <- expand.grid(grid_values, stringsAsFactors = FALSE)
+
+  # combinations without any forecast get a count of NA, replaced by 0
+  out <- merge(out, out_empty, by = by, all.y = TRUE)
+  out[, count := nafill(count, fill = 0)]
+
+  return(out[])
+}
+
+
+#' @title Visualise the number of available forecasts
+#'
+#' @description
+#' Visualise where forecasts are available.
+#' @param forecast_counts A data.table (or similar) with a column `count`
+#' holding forecast counts, as produced by [get_forecast_counts()].
+#' @param x Character vector of length one that denotes the name of the column
+#' to appear on the x-axis of the plot.
+#' @param y Character vector of length one that denotes the name of the column
+#' to appear on the y-axis of the plot. Default is "model".
+#' @param x_as_factor Logical (default is `TRUE`). Whether or not to convert
+#' the variable on the x-axis to a factor. This has an effect e.g. if dates
+#' are shown on the x-axis.
+#' @param show_counts Logical (default is `TRUE`) that indicates whether
+#' or not to show the actual count numbers on the plot.
+#' @return A ggplot object with a plot of forecast counts
+#' @importFrom ggplot2 ggplot scale_colour_manual scale_fill_manual
+#' geom_tile scale_fill_gradient .data
+#' @importFrom data.table dcast .I .N
+#' @importFrom checkmate assert_subset assert_logical
+#' @export
+#' @examples
+#' library(ggplot2)
+#' library(magrittr) # pipe operator
+#' forecast_counts <- example_quantile %>%
+#'   as_forecast_quantile %>%
+#'   get_forecast_counts(by = c("model", "target_type", "target_end_date"))
+#' plot_forecast_counts(
+#'   forecast_counts, x = "target_end_date", show_counts = FALSE
+#' ) +
+#'   facet_wrap("target_type")
+
+plot_forecast_counts <- function(forecast_counts,
+                                 x,
+                                 y = "model",
+                                 x_as_factor = TRUE,
+                                 show_counts = TRUE) {
+
+  forecast_counts <- ensure_data.table(forecast_counts)
+  assert_subset(y, colnames(forecast_counts))
+  assert_subset(x, colnames(forecast_counts))
+  assert_logical(x_as_factor, len = 1)
+  assert_logical(show_counts, len = 1)
+
+  if (x_as_factor) {
+    forecast_counts[, eval(x) := as.factor(get(x))]
+  }
+
+  # capitalised column name, which is also used as the fill legend title
+  setnames(forecast_counts, old = "count", new = "Count")
+
+  plot <- ggplot(
+    forecast_counts,
+    aes(y = .data[[y]], x = .data[[x]])
+  ) +
+    geom_tile(aes(fill = `Count`),
+              width = 0.97, height = 0.97) +
+    scale_fill_gradient(
+      low = "grey95", high = "steelblue",
+      na.value = "lightgrey"
+    ) +
+    theme_scoringutils() +
+    theme(
+      axis.text.x = element_text(
+        angle = 90, vjust = 1,
+        hjust = 1
+      )
+    ) +
+    theme(panel.spacing = unit(2, "lines"))
+  if (show_counts) {
+    plot <- plot +
+      geom_text(aes(label = `Count`))
+  }
+  return(plot)
+}
diff --git a/R/get-forecast-type.R b/R/get-forecast-type.R
new file mode 100644
index 000000000..3f6fe1abd
--- /dev/null
+++ b/R/get-forecast-type.R
@@ -0,0 +1,78 @@
+#' Get forecast type from forecast object
+#'
+#' The type is derived from the first class of the object: the class has to
+#' begin with `forecast_` and the remainder of the class name is the type.
+#' @inheritParams score
+#' @return
+#' Character vector of length one with the forecast type.
+#' @keywords internal_input_check
+get_forecast_type <- function(forecast) {
+  classname <- class(forecast)[1]
+  # the class has to *begin* with the prefix, as stated in the error message
+  if (!startsWith(classname, "forecast_")) {
+    cli_abort(
+      "Input is not a valid forecast object
+      (its first class should begin with `forecast_`)."
+    )
+  }
+  # strip only the leading prefix, e.g. "forecast_quantile" -> "quantile"
+  type <- sub("^forecast_", "", classname)
+  return(type)
+}
+
+
+#' Assert that forecast type is as expected
+#' @param data A forecast object (see [as_forecast()]).
+#' @param actual The actual forecast type of the data
+#' @param desired The desired forecast type of the data
+#' @inherit document_assert_functions return
+#' @importFrom cli cli_abort
+#' @importFrom checkmate assert_character
+#' @keywords internal_input_check
+assert_forecast_type <- function(data,
+                                 actual = get_forecast_type(data),
+                                 desired = NULL) {
+  assert_character(desired, null.ok = TRUE)
+  # nothing is checked when no desired type is given
+  if (!is.null(desired) && desired != actual) {
+    #nolint start: object_usage_linter keyword_quote_linter
+    cli_abort(
+      c(
+        "!" = "Forecast type determined by scoringutils based on input:
+        {.val {actual}}.",
+        "i" = "Desired forecast type: {.val {desired}}."
+      )
+    )
+    #nolint end
+  }
+  return(invisible(NULL))
+}
+
+
+#' @title Get type of a vector or matrix of observed values or predictions
+#'
+#' @description
+#' Internal helper function to get the type of a vector (usually
+#' of observed or predicted values).
The function checks whether the input is
+#' a factor, or else whether it is integer (or can be coerced to integer) or
+#' whether it's continuous.
+#' @param x Input the type should be determined for.
+#' @importFrom cli cli_abort
+#' @importFrom checkmate assert_numeric
+#' @return
+#' Character vector of length one with either "classification",
+#' "integer", or "continuous".
+#' @keywords internal_input_check
+get_type <- function(x) {
+  if (is.factor(x)) {
+    return("classification")
+  }
+  # drop attributes such as matrix dimensions before checking the values
+  values <- as.vector(x)
+  assert_numeric(values)
+  if (all(is.na(values))) {
+    cli_abort("Can't get type: all values are {.val NA}.")
+  }
+  if (is.integer(x)) {
+    return("integer")
+  }
+  # doubles count as "integer" if converting them to integer does not change
+  # any value (and does not turn everything into NA)
+  as_int <- as.integer(x)
+  if (isTRUE(all.equal(values, as_int)) && !all(is.na(as_int))) {
+    return("integer")
+  } else {
+    return("continuous")
+  }
+}
diff --git a/R/get-pit.R b/R/get-pit.R
new file mode 100644
index 000000000..a93a68e45
--- /dev/null
+++ b/R/get-pit.R
@@ -0,0 +1,177 @@
+#' @title Probability integral transformation (data.frame version)
+#'
+#' @description
+#' Compute the Probability Integral Transformation (PIT) for
+#' validated forecast objects.
+#'
+#' @inherit score params
+#' @param by Character vector with the columns according to which the
+#' PIT values shall be grouped. If you e.g. have the columns 'model' and
+#' 'location' in the input data and want to have a PIT histogram for
+#' every model and location, specify `by = c("model", "location")`.
+#' @inheritParams pit_sample
+#' @return A data.table with PIT values according to the grouping specified in
+#' `by`.
+#' @examples
+#' example <- as_forecast_sample(example_sample_continuous)
+#' result <- get_pit(example, by = "model")
+#' plot_pit(result)
+#'
+#' # example with quantile data
+#' example <- as_forecast_quantile(example_quantile)
+#' result <- get_pit(example, by = "model")
+#' plot_pit(result)
+#' @export
+#' @keywords scoring
+#' @references
+#' Sebastian Funk, Anton Camacho, Adam J. Kucharski, Rachel Lowe,
+#' Rosalind M. Eggo, W.
John Edmunds (2019) Assessing the performance of
+#' real-time epidemic forecasts: A case study of Ebola in the Western Area
+#' region of Sierra Leone, 2014-15, \doi{10.1371/journal.pcbi.1006785}
+get_pit <- function(forecast, by, ...) {
+  # S3 generic: dispatches on the class of `forecast`
+  UseMethod("get_pit")
+}
+
+
+#' @rdname get_pit
+#' @importFrom cli cli_abort
+#' @export
+get_pit.default <- function(forecast, by, ...) {
+  # fallback method: reached when no class-specific `get_pit` method applies,
+  # always aborts with an informative error
+  cli_abort(c(
+    "!" = "The input needs to be a valid forecast object represented as quantiles or samples." # nolint
+  ))
+}
+
+
+#' @title PIT histogram
+#'
+#' @description
+#' Make a simple histogram of the probability integral transformed values to
+#' visually check whether a uniform distribution seems likely.
+#'
+#' @param pit Either a vector with the PIT values, or a data.table as
+#' produced by [get_pit()].
+#' @param num_bins The number of bins in the PIT histogram, default is "auto".
+#' When `num_bins == "auto"`, [plot_pit()] will either display 10 bins, or it
+#' will display a bin for each available quantile in case you passed in data in
+#' a quantile-based format.
+#' You can control the number of bins by supplying a number. This is fine for
+#' sample-based pit histograms, but may fail for quantile-based formats. In this
+#' case it is preferred to supply explicit break points using the `breaks`
+#' argument.
+#' @param breaks Numeric vector with the break points for the bins in the
+#' PIT histogram. This is preferred when creating a PIT histogram based on
+#' quantile-based data. Default is `NULL` and breaks will be determined by
+#' `num_bins`. If `breaks` is used, `num_bins` will be ignored.
+#' @importFrom stats as.formula
+#' @importFrom ggplot2 geom_col
+#' @importFrom stats density
+#' @return A ggplot object with a histogram of PIT values
+#' @examples
+#' \dontshow{
+#'   data.table::setDTthreads(2) # restricts number of cores used on CRAN
+#' }
+#' library(magrittr) # pipe operator
+#'
+#' # PIT histogram in vector based format
+#' observed <- rnorm(30, mean = 1:30)
+#' predicted <- replicate(200, rnorm(n = 30, mean = 1:30))
+#' pit <- pit_sample(observed, predicted)
+#' plot_pit(pit)
+#'
+#' # quantile-based pit
+#' pit <- example_quantile %>%
+#'   as_forecast_quantile() %>%
+#'   get_pit(by = "model")
+#' plot_pit(pit, breaks = seq(0.1, 1, 0.1))
+#'
+#' # sample-based pit
+#' pit <- example_sample_discrete %>%
+#'   as_forecast_sample %>%
+#'   get_pit(by = "model")
+#' plot_pit(pit)
+#' @importFrom ggplot2 ggplot aes xlab ylab geom_histogram stat theme_light after_stat
+#' @importFrom checkmate assert check_set_equal check_number
+#' @export
+plot_pit <- function(pit,
+                     num_bins = "auto",
+                     breaks = NULL) {
+  # `num_bins` has to be either the string "auto" or a single number >= 1
+  assert(
+    check_set_equal(num_bins, "auto"),
+    check_number(num_bins, lower = 1)
+  )
+  assert_numeric(breaks, lower = 0, upper = 1, null.ok = TRUE)
+
+  # vector-format is always sample-based, for data.frames there are two options
+  if ("quantile_level" %in% names(pit)) {
+    type <- "quantile-based"
+  } else {
+    type <- "sample-based"
+  }
+
+  # use breaks if explicitly given, otherwise assign based on number of bins
+  if (!is.null(breaks)) {
+    # make sure the outer boundaries 0 and 1 are part of the break points
+    plot_quantiles <- unique(c(0, breaks, 1))
+  } else if (is.null(num_bins) || num_bins == "auto") {
+    # automatically set number of bins
+    if (type == "sample-based") {
+      num_bins <- 10
+      width <- 1 / num_bins
+      plot_quantiles <- seq(0, 1, width)
+    }
+    if (type == "quantile-based") {
+      # one break point per quantile level present in the data
+      plot_quantiles <- unique(c(0, pit$quantile_level, 1))
+    }
+  } else {
+    # if num_bins is explicitly given
+    width <- 1 / num_bins
+    plot_quantiles <- seq(0, 1, width)
+  }
+
+  # function for data.frames
+  if (is.data.frame(pit)) {
+    # one facet per combination of the forecast unit columns
+    facet_cols <- get_forecast_unit(pit)
+    formula <- as.formula(paste("~", paste(facet_cols, collapse = "+")))
+
+    # quantile version
+    if (type == "quantile-based") {
+      # only rows whose quantile level is one of the break points are plotted
+      hist <- ggplot(
+        data = pit[quantile_level %in% plot_quantiles],
+        aes(x = quantile_level, y = pit_value)
+      ) +
+        geom_col(position = "dodge", colour = "grey") +
+        facet_wrap(formula)
+    }
+
+    if (type == "sample-based") {
+      hist <- ggplot(
+        data = pit,
+        aes(x = pit_value)
+      ) +
+        # NOTE(review): `width` and `density` inside `after_stat()` presumably
+        # refer to the variables computed by the histogram stat, so that the
+        # bar height is the share of values per bin - confirm
+        geom_histogram(
+          aes(y = after_stat(width * density)),
+          breaks = plot_quantiles,
+          colour = "grey"
+        ) +
+        facet_wrap(formula)
+    }
+  } else {
+    # non data.frame version
+    hist <- ggplot(
+      data = data.frame(x = pit, stringsAsFactors = TRUE),
+      aes(x = x)
+    ) +
+      geom_histogram(
+        aes(y = after_stat(width * density)),
+        breaks = plot_quantiles,
+        colour = "grey"
+      )
+  }
+
+  # axis labels and theme shared by all variants
+  hist <- hist +
+    xlab("PIT") +
+    ylab("Frequency") +
+    theme_scoringutils()
+
+  return(hist)
+}
diff --git a/R/get-protected-columns.R b/R/get-protected-columns.R
new file mode 100644
index 000000000..e617672d5
--- /dev/null
+++ b/R/get-protected-columns.R
@@ -0,0 +1,37 @@
+#' @title Get protected columns from data
+#'
+#' @description Helper function to get the names of all columns in a data frame
+#' that are protected columns.
+#'
+#' @inheritParams as_forecast
+#'
+#' @return
+#' A character vector with the names of protected columns in the data.
+#' If data is `NULL` (default) then it returns a list of all columns that are
+#' protected in scoringutils.
+#'
+#' @keywords internal
+get_protected_columns <- function(data = NULL) {
+
+  # column names that are always protected
+  fixed_names <- c(
+    "predicted", "observed", "sample_id", "quantile_level", "upper", "lower",
+    "pit_value", "interval_range", "boundary", "predicted_label",
+    "interval_coverage", "interval_coverage_deviation",
+    "quantile_coverage", "quantile_coverage_deviation"
+  )
+
+  # names that are protected because of what they look like: anything ending
+  # in "_relative_skill" or containing "coverage_"
+  data_names <- names(data)
+  skill_names <- grep("_relative_skill$", data_names, value = TRUE)
+  coverage_names <- grep("coverage_", data_names, fixed = TRUE, value = TRUE)
+
+  protected_columns <- c(fixed_names, skill_names, coverage_names)
+
+  # without data there is nothing to match against: return the fixed names
+  if (is.null(data)) {
+    return(protected_columns)
+  }
+
+  # only return protected columns that are present, in the order in which
+  # they appear in the data
+  return(intersect(colnames(data), protected_columns))
+}
diff --git a/R/get_-functions.R b/R/get_-functions.R
deleted file mode 100644
index 7d9b1f005..000000000
--- a/R/get_-functions.R
+++ /dev/null
@@ -1,446 +0,0 @@
-# Functions that help to obtain information about the data
-
-#' Get forecast type from forecast object
-#' @inheritParams score
-#' @return
-#' Character vector of length one with the forecast type.
-#' @keywords internal_input_check
-get_forecast_type <- function(forecast) {
-  classname <- class(forecast)[1]
-  if (grepl("forecast_", classname, fixed = TRUE)) {
-    type <- gsub("forecast_", "", classname, fixed = TRUE)
-    return(type)
-  } else {
-    cli_abort(
-      "Input is not a valid forecast object
-      (it's first class should begin with `forecast_`)."
-    )
-  }
-}
-
-
-#' Assert that forecast type is as expected
-#' @param data A forecast object (see [as_forecast()]).
-#' @param actual The actual forecast type of the data -#' @param desired The desired forecast type of the data -#' @inherit document_assert_functions return -#' @importFrom cli cli_abort -#' @importFrom checkmate assert_character -#' @keywords internal_input_check -assert_forecast_type <- function(data, - actual = get_forecast_type(data), - desired = NULL) { - assert_character(desired, null.ok = TRUE) - if (!is.null(desired) && desired != actual) { - #nolint start: object_usage_linter keyword_quote_linter - cli_abort( - c( - "!" = "Forecast type determined by scoringutils based on input: - {.val {actual}}.", - "i" = "Desired forecast type: {.val {desired}}." - ) - ) - #nolint end - } - return(invisible(NULL)) -} - - -#' @title Get type of a vector or matrix of observed values or predictions -#' -#' @description -#' Internal helper function to get the type of a vector (usually -#' of observed or predicted values). The function checks whether the input is -#' a factor, or else whether it is integer (or can be coerced to integer) or -#' whether it's continuous. -#' @param x Input the type should be determined for. -#' @importFrom cli cli_abort -#' @return -#' Character vector of length one with either "classification", -#' "integer", or "continuous". -#' @keywords internal_input_check -get_type <- function(x) { - if (is.factor(x)) { - return("classification") - } - assert_numeric(as.vector(x)) - if (all(is.na(as.vector(x)))) { - cli_abort("Can't get type: all values of are {.val NA}.") - } - if (is.integer(x)) { - return("integer") - } - if ( - isTRUE(all.equal(as.vector(x), as.integer(x))) && !all(is.na(as.integer(x))) - ) { - return("integer") - } else { - return("continuous") - } -} - - -#' @title Get names of the metrics that were used for scoring -#' @description -#' When applying a scoring rule via [score()], the names of the scoring rules -#' become column names of the -#' resulting data.table. 
In addition, an attribute `metrics` will be -#' added to the output, holding the names of the scores as a vector. -#' -#' This is done so that functions like [get_forecast_unit()] or -#' [summarise_scores()] can still identify which columns are part of the -#' forecast unit and which hold a score. -#' -#' `get_metrics()` accesses and returns the `metrics` attribute. If there is no -#' attribute, the function will return `NULL` (or, if `error = TRUE` will -#' produce an error instead). In addition, it checks the column names of the -#' input for consistency with the data stored in the `metrics` attribute. -#' -#' **Handling a missing or inconsistent `metrics` attribute**: -#' -#' If the metrics attribute is missing or is not consistent with the column -#' names of the data.table, you can either -#' -#' - run [score()] again, specifying names for the scoring rules manually, or -#' - add/update the attribute manually using -#' `attr(scores, "metrics") <- c("names", "of", "your", "scores")` (the -#' order does not matter). -#' -#' @param x A `scores` object, (a data.table with an attribute `metrics` as -#' produced by [score()]). -#' @param error Throw an error if there is no attribute called `metrics`? -#' Default is FALSE. -#' @param ... unused -#' @importFrom cli cli_abort cli_warn -#' @importFrom checkmate assert_data_frame -#' @return -#' Character vector with the names of the scoring rules that were used -#' for scoring. -#' @keywords handle-metrics -#' @family `get_metrics` functions -#' @export -get_metrics.scores <- function(x, error = FALSE, ...) { - assert_data_frame(x) - metrics <- attr(x, "metrics") - if (error && is.null(metrics)) { - #nolint start: keyword_quote_linter - cli_abort( - c( - "!" = "Input needs an attribute `metrics` with the names of the - scoring rules that were used for scoring.", - "i" = "See `?get_metrics` for further information." 
- ) - ) - #nolint end - } - - if (!all(metrics %in% names(x))) { - #nolint start: keyword_quote_linter object_usage_linter - missing <- setdiff(metrics, names(x)) - cli_warn( - c( - "!" = "The following scores have been previously computed, but are no - longer column names of the data: {.val {missing}}", - "i" = "See {.code ?get_metrics} for further information." - ) - ) - #nolint end - } - - return(metrics) -} - - -#' @title Get unit of a single forecast -#' @description -#' Helper function to get the unit of a single forecast, i.e. -#' the column names that define where a single forecast was made for. -#' This just takes all columns that are available in the data and subtracts -#' the columns that are protected, i.e. those returned by -#' [get_protected_columns()] as well as the names of the metrics that were -#' specified during scoring, if any. -#' @inheritParams as_forecast -#' @inheritSection forecast_types Forecast unit -#' @return -#' A character vector with the column names that define the unit of -#' a single forecast -#' @importFrom checkmate assert_data_frame -#' @export -#' @keywords diagnose-inputs -get_forecast_unit <- function(data) { - assert_data_frame(data) - protected_columns <- get_protected_columns(data) - protected_columns <- c(protected_columns, attr(data, "metrics")) - forecast_unit <- setdiff(colnames(data), unique(protected_columns)) - return(forecast_unit) -} - - -#' @title Get protected columns from data -#' -#' @description Helper function to get the names of all columns in a data frame -#' that are protected columns. -#' -#' @inheritParams as_forecast -#' -#' @return -#' A character vector with the names of protected columns in the data. -#' If data is `NULL` (default) then it returns a list of all columns that are -#' protected in scoringutils. 
-#' -#' @keywords internal -get_protected_columns <- function(data = NULL) { - - protected_columns <- c( - "predicted", "observed", "sample_id", "quantile_level", "upper", "lower", - "pit_value", "interval_range", "boundary", "predicted_label", - "interval_coverage", "interval_coverage_deviation", - "quantile_coverage", "quantile_coverage_deviation", - grep("_relative_skill$", names(data), value = TRUE), - grep("coverage_", names(data), fixed = TRUE, value = TRUE) - ) - - if (is.null(data)) { - return(protected_columns) - } - - # only return protected columns that are present - datacols <- colnames(data) - protected_columns <- intersect( - datacols, - protected_columns - ) - - return(protected_columns) -} - - -#' @title Find duplicate forecasts -#' -#' @description -#' Internal helper function to identify duplicate forecasts, i.e. -#' instances where there is more than one forecast for the same prediction -#' target. -#' -#' @inheritParams as_forecast -#' @param counts Should the output show the number of duplicates per forecast -#' unit instead of the individual duplicated rows? Default is `FALSE`. 
-#' @return A data.frame with all rows for which a duplicate forecast was found -#' @export -#' @importFrom checkmate assert_data_frame assert_subset -#' @importFrom data.table setorderv -#' @keywords diagnose-inputs -#' @examples -#' example <- rbind(example_quantile, example_quantile[1000:1010]) -#' get_duplicate_forecasts(example) - -get_duplicate_forecasts <- function( - data, - forecast_unit = NULL, - counts = FALSE -) { - assert_data_frame(data) - data <- ensure_data.table(data) - - if (!is.null(forecast_unit)) { - data <- set_forecast_unit(data, forecast_unit) - } - forecast_unit <- get_forecast_unit(data) - available_type <- c("sample_id", "quantile_level", "predicted_label") %in% colnames(data) - type <- c("sample_id", "quantile_level", "predicted_label")[available_type] - data <- as.data.table(data) - data[, scoringutils_InternalDuplicateCheck := .N, by = c(forecast_unit, type)] - out <- data[scoringutils_InternalDuplicateCheck > 1] - - col <- colnames(data)[ - colnames(data) %in% c("sample_id", "quantile_level", "predicted_label") - ] - setorderv(out, cols = c(forecast_unit, col, "predicted")) - out[, scoringutils_InternalDuplicateCheck := NULL] - - if (counts) { - out <- out[, .(n_duplicates = .N), by = c(get_forecast_unit(out))] - } - - return(out[]) -} - - -#' @title Get quantile and interval coverage values for quantile-based forecasts -#' -#' @description -#' For a validated forecast object in a quantile-based format -#' (see [as_forecast()] for more information), this function computes: -#' - interval coverage of central prediction intervals -#' - quantile coverage for predictive quantiles -#' - the deviation between desired and actual coverage (both for interval and -#' quantile coverage) -#' -#' Coverage values are computed for a specific level of grouping, as specified -#' in the `by` argument. By default, coverage values are computed per model. 
-#' -#' **Interval coverage** -#' -#' Interval coverage for a given interval range is defined as the proportion of -#' observations that fall within the corresponding central prediction intervals. -#' Central prediction intervals are symmetric around the median and formed -#' by two quantiles that denote the lower and upper bound. For example, the 50% -#' central prediction interval is the interval between the 0.25 and 0.75 -#' quantiles of the predictive distribution. -#' -#' **Quantile coverage** -#' -#' Quantile coverage for a given quantile level is defined as the proportion of -#' observed values that are smaller than the corresponding predictive quantile. -#' For example, the 0.5 quantile coverage is the proportion of observed values -#' that are smaller than the 0.5 quantile of the predictive distribution. -#' Just as above, for a single observation and the quantile of a single -#' predictive distribution, the value will either be `TRUE` or `FALSE`. -#' -#' **Coverage deviation** -#' -#' The coverage deviation is the difference between the desired coverage -#' (can be either interval or quantile coverage) and the -#' actual coverage. For example, if the desired coverage is 90% and the actual -#' coverage is 80%, the coverage deviation is -0.1. -#' @return -#' A data.table with columns as specified in `by` and additional -#' columns for the coverage values described above -#' @inheritParams score -#' @param by character vector that denotes the level of grouping for which the -#' coverage values should be computed. By default (`"model"`), one coverage -#' value per model will be returned. -#' @return -#' a data.table with columns "interval_coverage", -#' "interval_coverage_deviation", "quantile_coverage", -#' "quantile_coverage_deviation" and the columns specified in `by`. 
-#' @importFrom data.table setcolorder -#' @importFrom checkmate assert_subset -#' @examples -#' library(magrittr) # pipe operator -#' example_quantile %>% -#' as_forecast_quantile() %>% -#' get_coverage(by = "model") -#' @export -#' @keywords scoring -#' @export -get_coverage <- function(forecast, by = "model") { - # input checks --------------------------------------------------------------- - forecast <- clean_forecast(forecast, copy = TRUE, na.omit = TRUE) - assert_subset(get_forecast_type(forecast), "quantile") - - # remove "quantile_level" and "interval_range" from `by` if present, as these - # are included anyway - by <- setdiff(by, c("quantile_level", "interval_range")) - assert_subset(by, names(forecast)) - - # convert to wide interval format and compute interval coverage -------------- - interval_forecast <- quantile_to_interval(forecast, format = "wide") - interval_forecast[, - interval_coverage := (observed <= upper) & (observed >= lower) - ][, c("lower", "upper", "observed") := NULL] - interval_forecast[, interval_coverage_deviation := - interval_coverage - interval_range / 100] - - # merge interval range data with original data ------------------------------- - # preparations - forecast[, interval_range := get_range_from_quantile(quantile_level)] - forecast_cols <- colnames(forecast) # store so we can reset column order later - forecast_unit <- get_forecast_unit(forecast) - - forecast <- merge(forecast, interval_forecast, - by = unique(c(forecast_unit, "interval_range"))) - - # compute quantile coverage and deviation ------------------------------------ - forecast[, quantile_coverage := observed <= predicted] - forecast[, quantile_coverage_deviation := quantile_coverage - quantile_level] - - # summarise coverage values according to `by` and cleanup -------------------- - # reset column order - new_metrics <- c("interval_coverage", "interval_coverage_deviation", - "quantile_coverage", "quantile_coverage_deviation") - setcolorder(forecast, 
unique(c(forecast_cols, "interval_range", new_metrics))) - # remove forecast class and convert to regular data.table - forecast <- as.data.table(forecast) - by <- unique(c(by, "quantile_level", "interval_range")) - # summarise - forecast <- forecast[, lapply(.SD, mean), by = by, .SDcols = new_metrics] - return(forecast[]) -} - - -#' @title Count number of available forecasts -#' -#' @description -#' Given a data set with forecasts, this function counts the number of -#' available forecasts. -#' The level of grouping can be specified using the `by` argument (e.g. to -#' count the number of forecasts per model, or the number of forecasts per -#' model and location). -#' This is useful to determine whether there are any missing forecasts. -#' -#' @param by character vector or `NULL` (the default) that denotes the -#' categories over which the number of forecasts should be counted. -#' By default this will be the unit of a single forecast (i.e. -#' all available columns (apart from a few "protected" columns such as -#' 'predicted' and 'observed') plus "quantile_level" or "sample_id" where -#' present). -#' -#' @param collapse character vector (default: `c("quantile_level", "sample_id"`) -#' with names of categories for which the number of rows should be collapsed -#' to one when counting. For example, a single forecast is usually represented -#' by a set of several quantiles or samples and collapsing these to one makes -#' sure that a single forecast only gets counted once. Setting -#' `collapse = c()` would mean that all quantiles / samples would be counted -#' as individual forecasts. -#' -#' @return A data.table with columns as specified in `by` and an additional -#' column "count" with the number of forecasts. 
-#' -#' @inheritParams score -#' @importFrom data.table .I .N nafill -#' @export -#' @keywords gain-insights -#' @examples -#' \dontshow{ -#' data.table::setDTthreads(2) # restricts number of cores used on CRAN -#' } -#' -#' library(magrittr) # pipe operator -#' example_quantile %>% -#' as_forecast_quantile() %>% -#' get_forecast_counts(by = c("model", "target_type")) -get_forecast_counts <- function(forecast, - by = get_forecast_unit(forecast), - collapse = c("quantile_level", "sample_id")) { - forecast <- clean_forecast(forecast, copy = TRUE, na.omit = TRUE) - forecast_unit <- get_forecast_unit(forecast) - assert_subset(by, names(forecast), empty.ok = FALSE) - forecast <- as.data.table(forecast) - - # collapse several rows to 1, e.g. treat a set of 10 quantiles as one, - # because they all belong to one single forecast that should be counted once - collapse_by <- setdiff( - c(forecast_unit, "quantile_level", "sample_id"), - collapse - ) - # filter "quantile_level", "sample" if in `collapse_by`, but not the forecast - collapse_by <- intersect(collapse_by, names(forecast)) - - forecast <- forecast[forecast[, .I[1], by = collapse_by]$V1] - - # count number of rows = number of forecasts - out <- forecast[, .(count = .N), by = by] - - # make sure that all combinations in "by" are included in the output (with - # count = 0). 
To achieve that, take unique values in `forecast` and expand grid - col_vecs <- unclass(out) - col_vecs$count <- NULL - col_vecs <- lapply(col_vecs, unique) - out_empty <- expand.grid(col_vecs, stringsAsFactors = FALSE) - - out <- merge(out, out_empty, by = by, all.y = TRUE) - out[, count := nafill(count, fill = 0)] - - return(out[]) -} diff --git a/R/utils_data_handling.R b/R/helper-quantile-interval-range.R similarity index 100% rename from R/utils_data_handling.R rename to R/helper-quantile-interval-range.R diff --git a/R/metrics-binary.R b/R/metrics-binary.R index 49cc800c9..148cc421b 100644 --- a/R/metrics-binary.R +++ b/R/metrics-binary.R @@ -1,3 +1,38 @@ +#' @title Assert that inputs are correct for binary forecast +#' @description +#' Function assesses whether the inputs correspond to the +#' requirements for scoring binary forecasts. +#' @param observed Input to be checked. Should be a factor of length n with +#' exactly two levels, holding the observed values. +#' The highest factor level is assumed to be the reference level. This means +#' that `predicted` represents the probability that the observed value is +#' equal to the highest factor level. +#' @param predicted Input to be checked. `predicted` should be a vector of +#' length n, holding probabilities. Alternatively, `predicted` can be a matrix +#' of size n x 1. Values represent the probability that +#' the corresponding value in `observed` will be equal to the highest +#' available factor level. 
+#' @importFrom checkmate assert assert_factor +#' @inherit document_assert_functions return +#' @keywords internal_input_check +assert_input_binary <- function(observed, predicted) { + assert_factor(observed, n.levels = 2, min.len = 1) + assert_numeric(predicted, lower = 0, upper = 1) + assert_dims_ok_point(observed, predicted) + return(invisible(NULL)) +} + + +#' @title Check that inputs are correct for binary forecast +#' @inherit assert_input_binary params description +#' @inherit document_check_functions return +#' @keywords internal_input_check +check_input_binary <- function(observed, predicted) { + result <- check_try(assert_input_binary(observed, predicted)) + return(result) +} + + #' Metrics for binary outcomes #' #' @details diff --git a/R/metrics-range.R b/R/metrics-interval-range.R similarity index 68% rename from R/metrics-range.R rename to R/metrics-interval-range.R index 2f8688951..947e05c30 100644 --- a/R/metrics-range.R +++ b/R/metrics-interval-range.R @@ -1,6 +1,72 @@ -################################################################################ -# Metrics with a one-to-one relationship between input and score -################################################################################ +# NOTE: the interval range format is only used internally. + + + +#' @title Assert that inputs are correct for interval-based forecast +#' @description +#' Function assesses whether the inputs correspond to the +#' requirements for scoring interval-based forecasts. +#' @param lower Input to be checked. Should be a numeric vector of size n that +#' holds the predicted value for the lower bounds of the prediction intervals. +#' @param upper Input to be checked. Should be a numeric vector of size n that +#' holds the predicted value for the upper bounds of the prediction intervals. +#' @param interval_range Input to be checked. Should be a vector of size n that +#' denotes the interval range in percent. E.g. 
a value of 50 denotes a +#' (25%, 75%) prediction interval. +#' @importFrom cli cli_warn cli_abort +#' @inherit document_assert_functions params return +#' @keywords internal_input_check +assert_input_interval <- function(observed, lower, upper, interval_range) { + + assert(check_numeric_vector(observed, min.len = 1)) + n <- length(observed) + assert(check_numeric_vector(lower, len = n)) + assert(check_numeric_vector(upper, len = n)) + assert( + check_numeric_vector(interval_range, len = 1, lower = 0, upper = 100), + check_numeric_vector(interval_range, len = n, lower = 0, upper = 100) + ) + + diff <- upper - lower + diff <- diff[!is.na(diff)] + if (any(diff < 0)) { + cli_abort( + c( + "!" = "All values in `upper` need to be greater than or equal to + the corresponding values in `lower`" + ) + ) + } + if (any(interval_range > 0 & interval_range < 1, na.rm = TRUE)) { + #nolint start: keyword_quote_linter + cli_warn( + c( + "!" = "Found interval ranges between 0 and 1. Are you sure that's + right? An interval range of 0.5 e.g. implies a (49.75%, 50.25%) + prediction interval.", + "i" = "If you want to score a (25%, 75%) prediction interval, set + `interval_range = 50`." 
+ ), + .frequency = "once", + .frequency_id = "small_interval_range" + ) + #nolint end + } + return(invisible(NULL)) +} + + +#' @title Check that inputs are correct for interval-based forecast +#' @inherit assert_input_interval params description +#' @inherit check_input_sample return description +#' @keywords internal_input_check +check_input_interval <- function(observed, lower, upper, interval_range) { + result <- check_try( + assert_input_interval(observed, lower, upper, interval_range) + ) + return(result) +} + #' @title Interval score #' diff --git a/R/metrics-nominal.R b/R/metrics-nominal.R index 63afc6c75..e01ed3d62 100644 --- a/R/metrics-nominal.R +++ b/R/metrics-nominal.R @@ -1,3 +1,68 @@ +#' @title Assert that inputs are correct for nominal forecasts +#' @description Function assesses whether the inputs correspond to the +#' requirements for scoring nominal forecasts. +#' @param observed Input to be checked. Should be a factor of length n with +#' N levels holding the observed values. n is the number of observations and +#' N is the number of possible outcomes the observed values can assume. +#' output) +#' @param predicted Input to be checked. Should be nxN matrix of predictive +#' quantiles, n (number of rows) being the number of data points and N +#' (number of columns) the number of possible outcomes the observed values +#' can assume. +#' If `observed` is just a single number, then predicted can just be a +#' vector of size N. +#' @param predicted Input to be checked. `predicted` should be a vector of +#' length n, holding probabilities. Alternatively, `predicted` can be a matrix +#' of size n x 1. Values represent the probability that +#' the corresponding value in `observed` will be equal to the highest +#' available factor level. +#' @param predicted_label Factor of length N with N levels, where N is the +#' number of possible outcomes the observed values can assume. 
+#' @importFrom checkmate assert_factor assert_numeric assert_set_equal +#' @inherit document_assert_functions return +#' @keywords internal_input_check +assert_input_nominal <- function(observed, predicted, predicted_label) { + # observed + assert_factor(observed, min.len = 1, min.levels = 2) + levels <- levels(observed) + n <- length(observed) + N <- length(levels) + + # predicted label + assert_factor( + predicted_label, len = N, + any.missing = FALSE, empty.levels.ok = FALSE + ) + assert_set_equal(levels(observed), levels(predicted_label)) + + # predicted + assert_numeric(predicted, min.len = 1, lower = 0, upper = 1) + if (n == 1) { + assert( + # allow one of two options + check_vector(predicted, len = N), + check_matrix(predicted, nrows = n, ncols = N) + ) + summed_predictions <- .rowSums(predicted, m = 1, n = N, na.rm = TRUE) + } else { + assert_matrix(predicted, nrows = n) + summed_predictions <- round(rowSums(predicted, na.rm = TRUE), 10) # avoid numeric errors + } + if (!all(summed_predictions == 1)) { + #nolint start: keyword_quote_linter object_usage_linter + row_indices <- as.character(which(summed_predictions != 1)) + cli_abort( + c( + `!` = "Probabilities belonging to a single forecast must sum to one", + `i` = "Found issues in row{?s} {row_indices} of {.var predicted}" + ) + ) + #nolint end + } + return(invisible(NULL)) +} + + #' Log score for nominal outcomes #' #' @description diff --git a/R/metrics-point.R b/R/metrics-point.R new file mode 100644 index 000000000..f31dc9030 --- /dev/null +++ b/R/metrics-point.R @@ -0,0 +1,73 @@ +#' @title Assert that inputs are correct for point forecast +#' @description +#' Function assesses whether the inputs correspond to the +#' requirements for scoring point forecasts. +#' @param predicted Input to be checked. Should be a numeric vector with the +#' predicted values of size n. 
+#' @inherit document_assert_functions params return +#' @keywords internal_input_check +assert_input_point <- function(observed, predicted) { + assert(check_numeric(observed)) + assert(check_numeric(predicted)) + assert(check_dims_ok_point(observed, predicted)) + return(invisible(NULL)) +} + + +#' @title Check that inputs are correct for point forecast +#' @inherit assert_input_point params description +#' @inherit document_check_functions return +#' @keywords internal_input_check +check_input_point <- function(observed, predicted) { + result <- check_try(assert_input_point(observed, predicted)) + return(result) +} + + +#' @title Assert Inputs Have Matching Dimensions +#' @description +#' Function assesses whether input dimensions match. In the +#' following, n is the number of observations / forecasts. Scalar values may +#' be repeated to match the length of the other input. +#' Allowed options are therefore: +#' - `observed` is a vector of length 1 or length n +#' - `predicted` is: +#' - a vector of length 1 or length n +#' - a matrix with n rows and 1 column +#' @inherit assert_input_binary +#' @inherit document_assert_functions return +#' @importFrom checkmate assert_vector check_matrix check_vector assert +#' @importFrom cli cli_abort +#' @keywords internal_input_check +assert_dims_ok_point <- function(observed, predicted) { + assert_vector(observed, min.len = 1) + n_obs <- length(observed) + assert( + check_vector(predicted, min.len = 1, strict = TRUE), + check_matrix(predicted, ncols = 1, nrows = n_obs) + ) + n_pred <- length(as.vector(predicted)) + # check that both are either of length 1 or of equal length + if ((n_obs != 1) && (n_pred != 1) && (n_obs != n_pred)) { + #nolint start: keyword_quote_linter object_usage_linter + cli_abort( + c( + "!" = "`observed` and `predicted` must either be of length 1 or + of equal length.", + "i" = "Found {n_obs} and {n_pred}."
+ ) + ) + #nolint end + } + return(invisible(NULL)) +} + + +#' @title Check Inputs Have Matching Dimensions +#' @inherit assert_dims_ok_point params description +#' @inherit document_check_functions return +#' @keywords internal_input_check +check_dims_ok_point <- function(observed, predicted) { + result <- check_try(assert_dims_ok_point(observed, predicted)) + return(result) +} diff --git a/R/metrics-quantile.R b/R/metrics-quantile.R index 5b57c3821..096f64061 100644 --- a/R/metrics-quantile.R +++ b/R/metrics-quantile.R @@ -1,6 +1,58 @@ -################################################################################ -# Metrics with a many-to-one relationship between input and score -################################################################################ +#' @title Assert that inputs are correct for quantile-based forecast +#' @description +#' Function assesses whether the inputs correspond to the +#' requirements for scoring quantile-based forecasts. +#' @param predicted Input to be checked. Should be nxN matrix of predictive +#' quantiles, n (number of rows) being the number of data points and N +#' (number of columns) the number of quantiles per forecast. +#' If `observed` is just a single number, then predicted can just be a +#' vector of size N. +#' @param quantile_level Input to be checked. Should be a vector of size N that +#' denotes the quantile levels corresponding to the columns of the prediction +#' matrix. +#' @param unique_quantile_levels Whether the quantile levels are required to be +#' unique (`TRUE`, the default) or not (`FALSE`). 
+#' @importFrom checkmate assert assert_numeric check_matrix check_vector +#' @inherit document_assert_functions params return +#' @keywords internal_input_check +assert_input_quantile <- function(observed, predicted, quantile_level, + unique_quantile_levels = TRUE) { + assert_numeric(observed, min.len = 1) + n_obs <- length(observed) + + assert_numeric( + quantile_level, min.len = 1, lower = 0, upper = 1, + unique = unique_quantile_levels + ) + n_quantiles <- length(quantile_level) + if (n_obs == 1) { + assert( + # allow one of two options + check_numeric_vector(predicted, min.len = n_quantiles), + check_matrix(predicted, mode = "numeric", + nrows = n_obs, ncols = n_quantiles) + ) + assert(check_vector(quantile_level, len = length(predicted))) + } else { + assert( + check_matrix(predicted, mode = "numeric", + nrows = n_obs, ncols = n_quantiles) + ) + } + return(invisible(NULL)) +} + +#' @title Check that inputs are correct for quantile-based forecast +#' @inherit assert_input_quantile params description +#' @inherit check_input_sample return description +#' @keywords internal_input_check +check_input_quantile <- function(observed, predicted, quantile_level) { + result <- check_try( + assert_input_quantile(observed, predicted, quantile_level) + ) + return(result) +} + #' Weighted interval score (WIS) #' @description diff --git a/R/metrics-sample.R b/R/metrics-sample.R index 74479d1c1..c88c6f26e 100644 --- a/R/metrics-sample.R +++ b/R/metrics-sample.R @@ -1,3 +1,41 @@ +#' @title Assert that inputs are correct for sample-based forecast +#' @description +#' Function assesses whether the inputs correspond to the requirements for +#' scoring sample-based forecasts. +#' @param predicted Input to be checked. Should be a numeric nxN matrix of +#' predictive samples, n (number of rows) being the number of data points and +#' N (number of columns) the number of samples per forecast. 
+#' If `observed` is just a single number, then predicted values can just be a +#' vector of size N. +#' @importFrom checkmate assert assert_numeric check_matrix assert_matrix +#' @inherit document_assert_functions params return +#' @keywords internal_input_check +assert_input_sample <- function(observed, predicted) { + assert_numeric(observed, min.len = 1) + n_obs <- length(observed) + + if (n_obs == 1) { + assert( + # allow one of two options + check_numeric_vector(predicted, min.len = 1), + check_matrix(predicted, mode = "numeric", nrows = n_obs) + ) + } else { + assert_matrix(predicted, mode = "numeric", nrows = n_obs) + } + return(invisible(NULL)) +} + +#' @title Check that inputs are correct for sample-based forecast +#' @inherit assert_input_sample params description +#' @inherit document_check_functions return +#' @keywords internal_input_check +check_input_sample <- function(observed, predicted) { + result <- check_try(assert_input_sample(observed, predicted)) + return(result) +} + + #' @title Determine bias of forecasts #' #' @description @@ -379,3 +417,107 @@ mad_sample <- function(observed = NULL, predicted, ...) { sharpness <- apply(predicted, MARGIN = 1, mad, ...) return(sharpness) } + + +#' @title Probability integral transformation for counts +#' +#' @description Uses a Probability integral transformation (PIT) (or a +#' randomised PIT for integer forecasts) to +#' assess the calibration of predictive Monte Carlo samples. +#' +#' @details +#' Calibration or reliability of forecasts is the ability of a model to +#' correctly identify its own uncertainty in making predictions. In a model +#' with perfect calibration, the observed data at each time point look as if +#' they came from the predictive probability distribution at that time. 
+#' +#' Equivalently, one can inspect the probability integral transform of the +#' predictive distribution at time t, +#' +#' \deqn{ +#' u_t = F_t (x_t) +#' } +#' +#' where \eqn{x_t} is the observed data point at time \eqn{t \textrm{ in } t_1, +#' …, t_n}{t in t_1, …, t_n}, n being the number of forecasts, and \eqn{F_t} is +#' the (continuous) predictive cumulative probability distribution at time t. If +#' the true probability distribution of outcomes at time t is \eqn{G_t} then the +#' forecasts \eqn{F_t} are said to be ideal if \eqn{F_t = G_t} at all times t. +#' In that case, the probabilities \eqn{u_t} are distributed uniformly. +#' +#' In the case of discrete nonnegative outcomes such as incidence counts, +#' the PIT is no longer uniform even when forecasts are ideal. +#' In that case a randomised PIT can be used instead: +#' \deqn{ +#' u_t = P_t(k_t) + v * (P_t(k_t) - P_t(k_t - 1) ) +#' } +#' +#' where \eqn{k_t} is the observed count, \eqn{P_t(x)} is the predictive +#' cumulative probability of observing incidence k at time t, +#' \eqn{P_t (-1) = 0} by definition and v is standard uniform and independent +#' of k. If \eqn{P_t} is the true cumulative +#' probability distribution, then \eqn{u_t} is standard uniform. +#' +#' @param n_replicates The number of draws for the randomised PIT for +#' discrete predictions. Will be ignored if forecasts are continuous. +#' @inheritParams ae_median_sample +#' @return A vector with PIT-values. For continuous forecasts, the vector will +#' correspond to the length of `observed`. For integer forecasts, a +#' randomised PIT will be returned of length +#' `length(observed) * n_replicates`. 
+#' @seealso [get_pit()] +#' @importFrom stats runif +#' @importFrom cli cli_abort cli_inform +#' @examples +#' \dontshow{ +#' data.table::setDTthreads(2) # restricts number of cores used on CRAN +#' } +#' +#' ## continuous predictions +#' observed <- rnorm(20, mean = 1:20) +#' predicted <- replicate(100, rnorm(n = 20, mean = 1:20)) +#' pit <- pit_sample(observed, predicted) +#' plot_pit(pit) +#' +#' ## integer predictions +#' observed <- rpois(20, lambda = 1:20) +#' predicted <- replicate(100, rpois(n = 20, lambda = 1:20)) +#' pit <- pit_sample(observed, predicted, n_replicates = 30) +#' plot_pit(pit) +#' @export +#' @references +#' Claudia Czado, Tilmann Gneiting, Leonhard Held (2009) Predictive model +#' assessment for count data. Biometrika, 96(4), 633-648. +# +#' Sebastian Funk, Anton Camacho, Adam J. Kucharski, Rachel Lowe, +#' Rosalind M. Eggo, W. John Edmunds (2019) Assessing the performance of +#' real-time epidemic forecasts: A case study of Ebola in the Western Area +#' region of Sierra Leone, 2014-15, \doi{10.1371/journal.pcbi.1006785} +#' @keywords metric +pit_sample <- function(observed, + predicted, + n_replicates = 100) { + assert_input_sample(observed = observed, predicted = predicted) + assert_number(n_replicates) + if (is.vector(predicted)) { + predicted <- matrix(predicted, nrow = 1) + } + + # calculate PIT-values ------------------------------------------------------- + n_pred <- ncol(predicted) + + # calculate empirical cumulative distribution function as + # proportion of (y_predicted <= y_observed) + p_x <- rowSums(predicted <= observed) / n_pred + + # PIT calculation is different for integer and continuous predictions + if (get_type(predicted) == "integer") { + p_xm1 <- rowSums(predicted <= (observed - 1)) / n_pred + pit_values <- as.vector( + replicate(n_replicates, p_xm1 + runif(1) * (p_x - p_xm1)) + ) + } else { + pit_values <- p_x + } + return(pit_values) +} diff --git a/R/metrics-validate.R b/R/metrics-validate.R deleted file mode 100644
index a26eb25ff..000000000 --- a/R/metrics-validate.R +++ /dev/null @@ -1,41 +0,0 @@ -#' @title Validate metrics -#' -#' @description -#' This function validates whether the list of metrics is a list -#' of valid functions. -#' -#' The function is used in [score()] to make sure that all metrics are valid -#' functions. -#' -#' @param metrics A named list with metrics. Every element should be a scoring -#' function to be applied to the data. -#' @importFrom cli cli_warn -#' -#' @return -#' A named list of metrics, with those filtered out that are not -#' valid functions -#' @importFrom checkmate assert_list test_list check_function -#' @keywords internal_input_check -validate_metrics <- function(metrics) { - - assert_list(metrics, min.len = 1, names = "named") - - for (i in seq_along(metrics)) { - check_fun <- check_function(metrics[[i]]) - if (!isTRUE(check_fun)) { - #nolint start: keyword_quote_linter - cli_warn( - c( - "!" = "`Metrics` element number {i} is not a valid function." - ) - ) - #nolint end - names(metrics)[i] <- "scoringutils_delete" - } - } - metrics[names(metrics) == "scoringutils_delete"] <- NULL - - assert_list(metrics, min.len = 1, .var.name = "valid metrics") - - return(metrics) -} diff --git a/R/metrics.R b/R/metrics.R new file mode 100644 index 000000000..f94735b8a --- /dev/null +++ b/R/metrics.R @@ -0,0 +1,68 @@ +#' @title Select metrics from a list of functions +#' +#' @description +#' Helper function to return only the scoring rules selected by +#' the user from a list of possible functions. +#' +#' @param metrics A list of scoring functions. +#' @param select A character vector of scoring rules to select from the list. If +#' `select` is `NULL` (the default), all possible scoring rules are returned. +#' @param exclude A character vector of scoring rules to exclude from the list. +#' If `select` is not `NULL`, this argument is ignored. +#' @return A list of scoring functions. 
+#' @keywords handle-metrics +#' @importFrom checkmate assert_subset assert_list +#' @export +#' @examples +#' select_metrics( +#' metrics = get_metrics(example_binary), +#' select = "brier_score" +#' ) +#' select_metrics( +#' metrics = get_metrics(example_binary), +#' exclude = "log_score" +#' ) +select_metrics <- function(metrics, select = NULL, exclude = NULL) { + assert_character(x = c(select, exclude), null.ok = TRUE) + assert_list(metrics, names = "named") + allowed <- names(metrics) + + if (is.null(select) && is.null(exclude)) { + return(metrics) + } + if (is.null(select)) { + assert_subset(exclude, allowed) + select <- allowed[!allowed %in% exclude] + return(metrics[select]) + } + assert_subset(select, allowed) + return(metrics[select]) +} + +#' Get metrics +#' +#' @description +#' Generic function to obtain default metrics available for scoring or metrics +#' that were used for scoring. +#' +#' - If called on a `forecast` object it returns a list of functions that can be +#' used for scoring. +#' - If called on a `scores` object (see [score()]), it returns a character vector +#' with the names of the metrics that were used for scoring. +#' +#' See the documentation for the actual methods in the `See Also` section below +#' for more details. Alternatively call `?get_metrics.` or +#' `?get_metrics.scores`. +#' +#' @param x A `forecast` or `scores` object. +#' @param ... Additional arguments passed to the method. +#' @details +#' See [as_forecast()] for more information on `forecast` objects and [score()] +#' for more information on `scores` objects. +#' +#' @family `get_metrics` functions +#' @keywords handle-metrics +#' @export +get_metrics <- function(x, ...)
{ + UseMethod("get_metrics") +} diff --git a/R/pairwise-comparisons.R b/R/pairwise-comparisons.R index d467f2e98..d42969f1c 100644 --- a/R/pairwise-comparisons.R +++ b/R/pairwise-comparisons.R @@ -619,3 +619,138 @@ add_relative_skill <- function( return(scores) } + + +#' @title Plot heatmap of pairwise comparisons +#' +#' @description +#' Creates a heatmap of the ratios or pvalues from a pairwise comparison +#' between models. +#' +#' @param comparison_result A data.frame as produced by +#' [get_pairwise_comparisons()]. +#' @param type Character vector of length one that is either +#' "mean_scores_ratio" or "pval". This denotes whether to +#' visualise the ratio or the p-value of the pairwise comparison. +#' Default is "mean_scores_ratio". +#' @importFrom ggplot2 ggplot aes geom_tile geom_text labs coord_cartesian +#' scale_fill_gradient2 theme_light element_text +#' @importFrom data.table as.data.table setnames rbindlist +#' @importFrom stats reorder +#' @importFrom ggplot2 labs coord_cartesian facet_wrap facet_grid theme +#' element_text element_blank +#' @return +#' A ggplot object with a heatmap of mean score ratios from pairwise +#' comparisons. 
+#' @export +#' @examples +#' library(ggplot2) +#' library(magrittr) # pipe operator +#' scores <- example_quantile %>% +#' as_forecast_quantile %>% +#' score() +#' pairwise <- get_pairwise_comparisons(scores, by = "target_type") +#' plot_pairwise_comparisons(pairwise, type = "mean_scores_ratio") + +#' facet_wrap(~target_type) + +plot_pairwise_comparisons <- function(comparison_result, + type = c("mean_scores_ratio", "pval")) { + comparison_result <- ensure_data.table(comparison_result) + type <- match.arg(type) + + relative_skill_metric <- grep( + "(?% +#' as_forecast_quantile %>% +#' score() +#' scores <- summarise_scores(scores, by = c("model", "target_type")) +#' scores <- summarise_scores( +#' scores, by = c("model", "target_type"), +#' fun = signif, digits = 2 +#' ) +#' +#' plot_heatmap(scores, x = "target_type", metric = "bias") + +plot_heatmap <- function(scores, + y = "model", + x, + metric) { + scores <- ensure_data.table(scores) + assert_subset(y, names(scores)) + assert_subset(x, names(scores)) + assert_subset(metric, names(scores)) + + plot <- ggplot( + scores, + aes( + y = .data[[y]], + x = .data[[x]], + fill = .data[[metric]] + ) + ) + + geom_tile() + + geom_text(aes(label = .data[[metric]])) + + scale_fill_gradient2(low = "steelblue", high = "salmon") + + theme_scoringutils() + + theme(axis.text.x = element_text( + angle = 90, vjust = 1, + hjust = 1 + )) + + coord_cartesian(expand = FALSE) + + return(plot) +} diff --git a/R/plot-wis.R b/R/plot-wis.R new file mode 100644 index 000000000..26a17c933 --- /dev/null +++ b/R/plot-wis.R @@ -0,0 +1,93 @@ +#' @title Plot contributions to the weighted interval score +#' +#' @description +#' Visualise the components of the weighted interval score: penalties for +#' over-prediction, under-prediction and for high dispersion (lack of +#' sharpness). +#' +#' @param scores A data.table of scores based on quantile forecasts as +#' produced by [score()] and summarised using [summarise_scores()]. 
+#' @param x The variable from the scores you want to show on the x-Axis. +#' Usually this will be "model". +#' @param relative_contributions Logical. Show relative contributions instead +#' of absolute contributions? Default is `FALSE`, which shows absolute +#' contributions. +#' @param flip Boolean (default is `FALSE`), whether or not to flip the axes. +#' @return A ggplot object showing the contributions from the three components of +#' the weighted interval score. +#' @importFrom ggplot2 ggplot aes geom_linerange facet_wrap labs +#' scale_fill_discrete coord_flip +#' theme theme_light unit guides guide_legend .data +#' @importFrom data.table melt +#' @importFrom checkmate assert_subset assert_logical +#' @return A ggplot object with a visualisation of the WIS decomposition +#' @export +#' @examples +#' library(ggplot2) +#' library(magrittr) # pipe operator +#' scores <- example_quantile %>% +#' as_forecast_quantile %>% +#' score() +#' scores <- summarise_scores(scores, by = c("model", "target_type")) +#' +#' plot_wis(scores, +#' x = "model", +#' relative_contributions = TRUE +#' ) + +#' facet_wrap(~target_type) +#' plot_wis(scores, +#' x = "model", +#' relative_contributions = FALSE +#' ) + +#' facet_wrap(~target_type, scales = "free_x") +#' @references +#' Bracher J, Ray E, Gneiting T, Reich, N (2020) Evaluating epidemic forecasts +#' in an interval format. 
+ +plot_wis <- function(scores, + x = "model", + relative_contributions = FALSE, + flip = FALSE) { + # input checks + scores <- ensure_data.table(scores) + wis_components <- c("overprediction", "underprediction", "dispersion") + assert(check_columns_present(scores, wis_components)) + assert_subset(x, names(scores)) + assert_logical(relative_contributions, len = 1) + assert_logical(flip, len = 1) + + scores <- melt( + scores, + measure.vars = wis_components, + variable.name = "wis_component_name", + value.name = "component_value" + ) + + # stack or fill the geom_col position + col_position <- ifelse(relative_contributions, "fill", "stack") + + plot <- ggplot(scores, aes(y = .data[[x]])) + + geom_col( + position = col_position, + aes(x = component_value, fill = wis_component_name) + ) + + theme_scoringutils() + + scale_fill_discrete(type = c("#DF536B", "#61D04F", "#2297E6")) + + guides(fill = guide_legend(title = "WIS component")) + + xlab("WIS contributions") + + if (flip) { + plot <- plot + + theme( + panel.spacing = unit(4, "mm"), + axis.text.x = element_text( + angle = 90, + vjust = 1, + hjust = 1 + ) + ) + + coord_flip() + } + + return(plot) +} diff --git a/R/plot.R b/R/plot.R deleted file mode 100644 index e382dcc55..000000000 --- a/R/plot.R +++ /dev/null @@ -1,750 +0,0 @@ -#' @title Plot contributions to the weighted interval score -#' -#' @description -#' Visualise the components of the weighted interval score: penalties for -#' over-prediction, under-prediction and for high dispersion (lack of -#' sharpness). -#' -#' @param scores A data.table of scores based on quantile forecasts as -#' produced by [score()] and summarised using [summarise_scores()]. -#' @param x The variable from the scores you want to show on the x-Axis. -#' Usually this will be "model". -#' @param relative_contributions Logical. Show relative contributions instead -#' of absolute contributions? Default is `FALSE` and this functionality is not -#' available yet. 
-#' @param flip Boolean (default is `FALSE`), whether or not to flip the axes. -#' @return A ggplot object showing a contributions from the three components of -#' the weighted interval score. -#' @importFrom ggplot2 ggplot aes geom_linerange facet_wrap labs -#' scale_fill_discrete coord_flip -#' theme theme_light unit guides guide_legend .data -#' @importFrom data.table melt -#' @importFrom checkmate assert_subset assert_logical -#' @return A ggplot object with a visualisation of the WIS decomposition -#' @export -#' @examples -#' library(ggplot2) -#' library(magrittr) # pipe operator -#' scores <- example_quantile %>% -#' as_forecast_quantile %>% -#' score() -#' scores <- summarise_scores(scores, by = c("model", "target_type")) -#' -#' plot_wis(scores, -#' x = "model", -#' relative_contributions = TRUE -#' ) + -#' facet_wrap(~target_type) -#' plot_wis(scores, -#' x = "model", -#' relative_contributions = FALSE -#' ) + -#' facet_wrap(~target_type, scales = "free_x") -#' @references -#' Bracher J, Ray E, Gneiting T, Reich, N (2020) Evaluating epidemic forecasts -#' in an interval format. 
- -plot_wis <- function(scores, - x = "model", - relative_contributions = FALSE, - flip = FALSE) { - # input checks - scores <- ensure_data.table(scores) - wis_components <- c("overprediction", "underprediction", "dispersion") - assert(check_columns_present(scores, wis_components)) - assert_subset(x, names(scores)) - assert_logical(relative_contributions, len = 1) - assert_logical(flip, len = 1) - - scores <- melt(scores, - measure.vars = wis_components, - variable.name = "wis_component_name", - value.name = "component_value" - ) - - # stack or fill the geom_col position - col_position <- ifelse(relative_contributions, "fill", "stack") - - plot <- ggplot(scores, aes(y = .data[[x]])) + - geom_col( - position = col_position, - aes(x = component_value, fill = wis_component_name) - ) + - theme_scoringutils() + - scale_fill_discrete(type = c("#DF536B", "#61D04F", "#2297E6")) + - guides(fill = guide_legend(title = "WIS component")) + - xlab("WIS contributions") - - if (flip) { - plot <- plot + - theme( - panel.spacing = unit(4, "mm"), - axis.text.x = element_text( - angle = 90, - vjust = 1, - hjust = 1 - ) - ) + - coord_flip() - } - - return(plot) -} - - -#' @title Create a heatmap of a scoring metric -#' -#' @description -#' This function can be used to create a heatmap of one metric across different -#' groups, e.g. the interval score obtained by several forecasting models in -#' different locations. -#' -#' @param scores A data.frame of scores based on quantile forecasts as -#' produced by [score()]. -#' @param y The variable from the scores you want to show on the y-Axis. The -#' default for this is "model" -#' @param x The variable from the scores you want to show on the x-Axis. This -#' could be something like "horizon", or "location" -#' @param metric String, the metric that determines the value and colour shown -#' in the tiles of the heatmap. 
-#' @return A ggplot object showing a heatmap of the desired metric -#' @importFrom data.table setDT `:=` -#' @importFrom ggplot2 ggplot aes geom_tile geom_text .data -#' scale_fill_gradient2 labs element_text coord_cartesian -#' @importFrom checkmate assert_subset -#' @export -#' @examples -#' library(magrittr) # pipe operator -#' scores <- example_quantile %>% -#' as_forecast_quantile %>% -#' score() -#' scores <- summarise_scores(scores, by = c("model", "target_type")) -#' scores <- summarise_scores( -#' scores, by = c("model", "target_type"), -#' fun = signif, digits = 2 -#' ) -#' -#' plot_heatmap(scores, x = "target_type", metric = "bias") - -plot_heatmap <- function(scores, - y = "model", - x, - metric) { - scores <- ensure_data.table(scores) - assert_subset(y, names(scores)) - assert_subset(x, names(scores)) - assert_subset(metric, names(scores)) - - plot <- ggplot( - scores, - aes( - y = .data[[y]], - x = .data[[x]], - fill = .data[[metric]] - ) - ) + - geom_tile() + - geom_text(aes(label = .data[[metric]])) + - scale_fill_gradient2(low = "steelblue", high = "salmon") + - theme_scoringutils() + - theme(axis.text.x = element_text( - angle = 90, vjust = 1, - hjust = 1 - )) + - coord_cartesian(expand = FALSE) - - return(plot) -} - - -#' @title Plot interval coverage -#' -#' @description -#' Plot interval coverage values (see [get_coverage()] for more information). -#' -#' @param coverage A data frame of coverage values as produced by -#' [get_coverage()]. -#' @param colour According to which variable shall the graphs be coloured? -#' Default is "model". 
-#' @return ggplot object with a plot of interval coverage -#' @importFrom ggplot2 ggplot scale_colour_manual scale_fill_manual .data -#' facet_wrap facet_grid geom_polygon geom_line -#' @importFrom checkmate assert_subset -#' @importFrom data.table dcast -#' @export -#' @examples -#' \dontshow{ -#' data.table::setDTthreads(2) # restricts number of cores used on CRAN -#' } -#' example <- as_forecast_quantile(example_quantile) -#' coverage <- get_coverage(example, by = "model") -#' plot_interval_coverage(coverage) -plot_interval_coverage <- function(coverage, - colour = "model") { - coverage <- ensure_data.table(coverage) - assert_subset(colour, names(coverage)) - - # in case quantile columns are present, remove them and then take unique - # values. This doesn't visually affect the plot, but prevents lines from being - # drawn twice. - del <- c("quantile_level", "quantile_coverage", "quantile_coverage_deviation") - suppressWarnings(coverage[, eval(del) := NULL]) - coverage <- unique(coverage) - - ## overall model calibration - empirical interval coverage - p1 <- ggplot(coverage, aes( - x = interval_range, - colour = .data[[colour]] - )) + - geom_polygon( - data = data.frame( - x = c(0, 0, 100), - y = c(0, 100, 100), - g = c("o", "o", "o"), - stringsAsFactors = TRUE - ), - aes( - x = x, y = y, group = g, - fill = g - ), - alpha = 0.05, - colour = "white", - fill = "olivedrab3" - ) + - geom_line(aes(y = interval_range), - colour = "grey", - linetype = "dashed" - ) + - geom_line(aes(y = interval_coverage * 100)) + - theme_scoringutils() + - ylab("% Obs inside interval") + - xlab("Nominal interval coverage") + - coord_cartesian(expand = FALSE) - - return(p1) -} - -#' @title Plot quantile coverage -#' -#' @description -#' Plot quantile coverage values (see [get_coverage()] for more information). -#' -#' @inheritParams plot_interval_coverage -#' @param colour String, according to which variable shall the graphs be -#' coloured? Default is "model". 
-#' @return A ggplot object with a plot of interval coverage -#' @importFrom ggplot2 ggplot scale_colour_manual scale_fill_manual .data aes -#' scale_y_continuous geom_line -#' @importFrom checkmate assert_subset assert_data_frame -#' @importFrom data.table dcast -#' @export -#' @examples -#' example <- as_forecast_quantile(example_quantile) -#' coverage <- get_coverage(example, by = "model") -#' plot_quantile_coverage(coverage) - -plot_quantile_coverage <- function(coverage, - colour = "model") { - coverage <- assert_data_frame(coverage) - assert_subset(colour, names(coverage)) - - p2 <- ggplot( - data = coverage, - aes(x = quantile_level, colour = .data[[colour]]) - ) + - geom_polygon( - data = data.frame( - x = c( - 0, 0.5, 0.5, - 0.5, 0.5, 1 - ), - y = c( - 0, 0, 0.5, - 0.5, 1, 1 - ), - g = c("o", "o", "o"), - stringsAsFactors = TRUE - ), - aes( - x = x, y = y, group = g, - fill = g - ), - alpha = 0.05, - colour = "white", - fill = "olivedrab3" - ) + - geom_line(aes(y = quantile_level), - colour = "grey", - linetype = "dashed" - ) + - geom_line(aes(y = quantile_coverage)) + - theme_scoringutils() + - xlab("Quantile level") + - ylab("% Obs below quantile level") + - scale_y_continuous( - labels = function(x) { - paste(100 * x) - } - ) + - coord_cartesian(expand = FALSE) - - return(p2) -} - -#' @title Plot heatmap of pairwise comparisons -#' -#' @description -#' Creates a heatmap of the ratios or pvalues from a pairwise comparison -#' between models. -#' -#' @param comparison_result A data.frame as produced by -#' [get_pairwise_comparisons()]. -#' @param type Character vector of length one that is either -#' "mean_scores_ratio" or "pval". This denotes whether to -#' visualise the ratio or the p-value of the pairwise comparison. -#' Default is "mean_scores_ratio". 
-#' @importFrom ggplot2 ggplot aes geom_tile geom_text labs coord_cartesian -#' scale_fill_gradient2 theme_light element_text -#' @importFrom data.table as.data.table setnames rbindlist -#' @importFrom stats reorder -#' @importFrom ggplot2 labs coord_cartesian facet_wrap facet_grid theme -#' element_text element_blank -#' @return -#' A ggplot object with a heatmap of mean score ratios from pairwise -#' comparisons. -#' @export -#' @examples -#' library(ggplot2) -#' library(magrittr) # pipe operator -#' scores <- example_quantile %>% -#' as_forecast_quantile %>% -#' score() -#' pairwise <- get_pairwise_comparisons(scores, by = "target_type") -#' plot_pairwise_comparisons(pairwise, type = "mean_scores_ratio") + -#' facet_wrap(~target_type) - -plot_pairwise_comparisons <- function(comparison_result, - type = c("mean_scores_ratio", "pval")) { - comparison_result <- ensure_data.table(comparison_result) - type <- match.arg(type) - - relative_skill_metric <- grep( - "(?% -#' as_forecast_quantile() %>% -#' get_pit(by = "model") -#' plot_pit(pit, breaks = seq(0.1, 1, 0.1)) -#' -#' # sample-based pit -#' pit <- example_sample_discrete %>% -#' as_forecast_sample %>% -#' get_pit(by = "model") -#' plot_pit(pit) -#' @importFrom ggplot2 ggplot aes xlab ylab geom_histogram stat theme_light after_stat -#' @importFrom checkmate assert check_set_equal check_number -#' @export - -plot_pit <- function(pit, - num_bins = "auto", - breaks = NULL) { - assert( - check_set_equal(num_bins, "auto"), - check_number(num_bins, lower = 1) - ) - assert_numeric(breaks, lower = 0, upper = 1, null.ok = TRUE) - - # vector-format is always sample-based, for data.frames there are two options - if ("quantile_level" %in% names(pit)) { - type <- "quantile-based" - } else { - type <- "sample-based" - } - - # use breaks if explicitly given, otherwise assign based on number of bins - if (!is.null(breaks)) { - plot_quantiles <- unique(c(0, breaks, 1)) - } else if (is.null(num_bins) || num_bins == "auto") { - # 
automatically set number of bins - if (type == "sample-based") { - num_bins <- 10 - width <- 1 / num_bins - plot_quantiles <- seq(0, 1, width) - } - if (type == "quantile-based") { - plot_quantiles <- unique(c(0, pit$quantile_level, 1)) - } - } else { - # if num_bins is explicitly given - width <- 1 / num_bins - plot_quantiles <- seq(0, 1, width) - } - - # function for data.frames - if (is.data.frame(pit)) { - facet_cols <- get_forecast_unit(pit) - formula <- as.formula(paste("~", paste(facet_cols, collapse = "+"))) - - # quantile version - if (type == "quantile-based") { - hist <- ggplot( - data = pit[quantile_level %in% plot_quantiles], - aes(x = quantile_level, y = pit_value) - ) + - geom_col(position = "dodge", colour = "grey") + - facet_wrap(formula) - } - - if (type == "sample-based") { - hist <- ggplot( - data = pit, - aes(x = pit_value) - ) + - geom_histogram(aes(y = after_stat(width * density)), - breaks = plot_quantiles, - colour = "grey" - ) + - facet_wrap(formula) - } - } else { - # non data.frame version - hist <- ggplot( - data = data.frame(x = pit, stringsAsFactors = TRUE), - aes(x = x) - ) + - geom_histogram(aes(y = after_stat(width * density)), - breaks = plot_quantiles, - colour = "grey" - ) - } - - hist <- hist + - xlab("PIT") + - ylab("Frequency") + - theme_scoringutils() - - return(hist) -} - -#' @title Visualise the number of available forecasts -#' -#' @description -#' Visualise Where Forecasts Are Available. -#' @param forecast_counts A data.table (or similar) with a column `count` -#' holding forecast counts, as produced by [get_forecast_counts()]. -#' @param x Character vector of length one that denotes the name of the column -#' to appear on the x-axis of the plot. -#' @param y Character vector of length one that denotes the name of the column -#' to appear on the y-axis of the plot. Default is "model". -#' @param x_as_factor Logical (default is `TRUE`). Whether or not to convert -#' the variable on the x-axis to a factor. 
This has an effect e.g. if dates -#' are shown on the x-axis. -#' @param show_counts Logical (default is `TRUE`) that indicates whether -#' or not to show the actual count numbers on the plot. -#' @return A ggplot object with a plot of forecast counts -#' @importFrom ggplot2 ggplot scale_colour_manual scale_fill_manual -#' geom_tile scale_fill_gradient .data -#' @importFrom data.table dcast .I .N -#' @importFrom checkmate assert_subset assert_logical -#' @export -#' @examples -#' library(ggplot2) -#' library(magrittr) # pipe operator -#' forecast_counts <- example_quantile %>% -#' as_forecast_quantile %>% -#' get_forecast_counts(by = c("model", "target_type", "target_end_date")) -#' plot_forecast_counts( -#' forecast_counts, x = "target_end_date", show_counts = FALSE -#' ) + -#' facet_wrap("target_type") - -plot_forecast_counts <- function(forecast_counts, - x, - y = "model", - x_as_factor = TRUE, - show_counts = TRUE) { - - forecast_counts <- ensure_data.table(forecast_counts) - assert_subset(y, colnames(forecast_counts)) - assert_subset(x, colnames(forecast_counts)) - assert_logical(x_as_factor, len = 1) - assert_logical(show_counts, len = 1) - - if (x_as_factor) { - forecast_counts[, eval(x) := as.factor(get(x))] - } - - setnames(forecast_counts, old = "count", new = "Count") - - plot <- ggplot( - forecast_counts, - aes(y = .data[[y]], x = .data[[x]]) - ) + - geom_tile(aes(fill = `Count`), - width = 0.97, height = 0.97) + - scale_fill_gradient( - low = "grey95", high = "steelblue", - na.value = "lightgrey" - ) + - theme_scoringutils() + - theme( - axis.text.x = element_text( - angle = 90, vjust = 1, - hjust = 1 - ) - ) + - theme(panel.spacing = unit(2, "lines")) - if (show_counts) { - plot <- plot + - geom_text(aes(label = `Count`)) - } - return(plot) -} - - -#' @title Plot correlation between metrics -#' -#' @description -#' Plots a heatmap of correlations between different metrics. 
-#' -#' @param correlations A data.table of correlations between scores as produced -#' by [get_correlations()]. -#' @param digits A number indicating how many decimal places the correlations -#' should be rounded to. By default (`digits = NULL`) no rounding takes place. -#' @return -#' A ggplot object showing a coloured matrix of correlations between metrics. -#' @importFrom ggplot2 ggplot geom_tile geom_text aes scale_fill_gradient2 -#' element_text labs coord_cartesian theme element_blank -#' @importFrom data.table setDT melt -#' @importFrom checkmate assert_data_frame -#' @export -#' @return A ggplot object with a visualisation of correlations between metrics -#' @examples -#' library(magrittr) # pipe operator -#' scores <- example_quantile %>% -#' as_forecast_quantile %>% -#' score() -#' correlations <- scores %>% -#' summarise_scores() %>% -#' get_correlations() -#' plot_correlations(correlations, digits = 2) - -plot_correlations <- function(correlations, digits = NULL) { - - assert_data_frame(correlations) - metrics <- get_metrics.scores(correlations, error = TRUE) - - lower_triangle <- get_lower_tri(correlations[, .SD, .SDcols = metrics]) - - if (!is.null(digits)) { - lower_triangle <- round(lower_triangle, digits) - } - - - # check correlations is actually a matrix of correlations - col_present <- check_columns_present(correlations, "metric") - if (any(lower_triangle > 1, na.rm = TRUE) || !isTRUE(col_present)) { - #nolint start: keyword_quote_linter - cli_abort( - c( - "Found correlations > 1 or missing `metric` column.", - "i" = "Did you forget to call {.fn scoringutils::get_correlations}?" 
- ) - ) - #nolint end - } - - rownames(lower_triangle) <- colnames(lower_triangle) - - # get plot data.frame - plot_df <- data.table::as.data.table(lower_triangle)[, metric := metrics] - plot_df <- na.omit(data.table::melt(plot_df, id.vars = "metric")) - - # refactor levels according to the metrics - plot_df[, metric := factor(metric, levels = metrics)] - plot_df[, variable := factor(variable, rev(metrics))] - - plot <- ggplot(plot_df, aes( - x = variable, y = metric, - fill = value - )) + - geom_tile( - color = "white", - width = 0.97, height = 0.97 - ) + - geom_text(aes(y = metric, label = value)) + - scale_fill_gradient2( - low = "steelblue", mid = "white", - high = "salmon", - name = "Correlation", - breaks = c(-1, -0.5, 0, 0.5, 1) - ) + - theme_scoringutils() + - theme( - axis.text.x = element_text( - angle = 90, vjust = 1, - hjust = 1 - ) - ) + - labs(x = "", y = "") + - coord_cartesian(expand = FALSE) - return(plot) -} - - -# helper function to obtain lower triangle of matrix -get_lower_tri <- function(cormat) { - cormat[lower.tri(cormat)] <- NA - return(cormat) -} - - -#' @title Scoringutils ggplot2 theme -#' -#' @description -#' A theme for ggplot2 plots used in `scoringutils`. 
-#' @return A ggplot2 theme -#' @importFrom ggplot2 theme theme_minimal element_line `%+replace%` -#' @keywords plotting -#' @export -theme_scoringutils <- function() { - theme_minimal() %+replace% - theme(axis.line = element_line(colour = "grey80"), - axis.ticks = element_line(colour = "grey80"), - panel.grid.major = element_blank(), - panel.grid.minor = element_blank(), - panel.border = element_blank(), - panel.background = element_blank(), - legend.position = "bottom") -} diff --git a/R/print.R b/R/print.R deleted file mode 100644 index 58c775d84..000000000 --- a/R/print.R +++ /dev/null @@ -1,67 +0,0 @@ -#' @title Print information about a forecast object -#' @description -#' This function prints information about a forecast object, -#' including "Forecast type", "Score columns", -#' "Forecast unit". -#' -#' @param x A forecast object (a validated data.table with predicted and -#' observed values, see [as_forecast()]). -#' @param ... Additional arguments for [print()]. -#' @return Returns `x` invisibly. -#' @importFrom cli cli_inform cli_warn col_blue cli_text -#' @export -#' @keywords gain-insights -#' @examples -#' dat <- as_forecast_quantile(example_quantile) -#' print(dat) -print.forecast <- function(x, ...) { - - # get forecast type, forecast unit and score columns - forecast_type <- try( - do.call(get_forecast_type, list(forecast = x)), - silent = TRUE - ) - forecast_unit <- try( - do.call(get_forecast_unit, list(data = x)), - silent = TRUE - ) - - # Print forecast object information - if (inherits(forecast_type, "try-error")) { - cli_inform( - c( - "!" = "Could not determine forecast type due to error in validation." #nolint - ) - ) - } else { - cli_text( - col_blue( - "Forecast type: " - ), - "{forecast_type}" - ) - } - - if (inherits(forecast_unit, "try-error")) { - cli_inform( - c( - "!" = "Could not determine forecast unit." 
#nolint - ) - ) - } else { - cli_text( - col_blue( - "Forecast unit:" - ) - ) - cli_text( - "{forecast_unit}" - ) - } - - cat("\n") - - NextMethod() - - return(invisible(x)) -} diff --git a/R/score.R b/R/score.R index dc7199c7e..2ebb933be 100644 --- a/R/score.R +++ b/R/score.R @@ -105,161 +105,6 @@ score.default <- function(forecast, metrics, ...) { ) } -#' @importFrom stats na.omit -#' @importFrom data.table setattr copy -#' @rdname score -#' @export -score.forecast_binary <- function(forecast, metrics = get_metrics(forecast), ...) { - forecast <- clean_forecast(forecast, copy = TRUE, na.omit = TRUE) - metrics <- validate_metrics(metrics) - forecast <- as.data.table(forecast) - - scores <- apply_metrics( - forecast, metrics, - forecast$observed, forecast$predicted - ) - scores[, `:=`(predicted = NULL, observed = NULL)] - - scores <- as_scores(scores, metrics = names(metrics)) - return(scores[]) -} - - -#' @importFrom stats na.omit -#' @importFrom data.table setattr -#' @rdname score -#' @export -score.forecast_nominal <- function(forecast, metrics = get_metrics(forecast), ...) { - forecast <- clean_forecast(forecast, copy = TRUE, na.omit = TRUE) - forecast_unit <- get_forecast_unit(forecast) - metrics <- validate_metrics(metrics) - forecast <- as.data.table(forecast) - - # transpose the forecasts that belong to the same forecast unit - # make sure the labels and predictions are ordered in the same way - f_transposed <- forecast[, .( - predicted = list(predicted[order(predicted_label)]), - observed = unique(observed) - ), by = forecast_unit] - - observed <- f_transposed$observed - predicted <- do.call(rbind, f_transposed$predicted) - predicted_label <- sort(unique(forecast$predicted_label, na.last = TRUE)) - f_transposed[, c("observed", "predicted") := NULL] - - scores <- apply_metrics( - f_transposed, metrics, - observed, predicted, predicted_label, ... 
- ) - scores <- as_scores(scores, metrics = names(metrics)) - return(scores[]) -} - - -#' @importFrom Metrics se ae ape -#' @importFrom stats na.omit -#' @importFrom data.table setattr copy -#' @rdname score -#' @export -score.forecast_point <- function(forecast, metrics = get_metrics(forecast), ...) { - forecast <- clean_forecast(forecast, copy = TRUE, na.omit = TRUE) - metrics <- validate_metrics(metrics) - forecast <- as.data.table(forecast) - - scores <- apply_metrics( - forecast, metrics, - forecast$observed, forecast$predicted - ) - scores[, `:=`(predicted = NULL, observed = NULL)] - - scores <- as_scores(scores, metrics = names(metrics)) - return(scores[]) -} - -#' @importFrom stats na.omit -#' @importFrom data.table setattr copy -#' @rdname score -#' @export -score.forecast_sample <- function(forecast, metrics = get_metrics(forecast), ...) { - forecast <- clean_forecast(forecast, copy = TRUE, na.omit = TRUE) - forecast_unit <- get_forecast_unit(forecast) - metrics <- validate_metrics(metrics) - forecast <- as.data.table(forecast) - - # transpose the forecasts that belong to the same forecast unit - f_transposed <- forecast[, .(predicted = list(predicted), - observed = unique(observed), - scoringutils_N = length(list(sample_id))), - by = forecast_unit] - - # split according to number of samples and do calculations for different - # sample lengths separately - f_split <- split(f_transposed, f_transposed$scoringutils_N) - - split_result <- lapply(f_split, function(forecast) { - # create a matrix - observed <- forecast$observed - predicted <- do.call(rbind, forecast$predicted) - forecast[, c("observed", "predicted", "scoringutils_N") := NULL] - - forecast <- apply_metrics( - forecast, metrics, - observed, predicted - ) - return(forecast) - }) - scores <- rbindlist(split_result, fill = TRUE) - scores <- as_scores(scores, metrics = names(metrics)) - return(scores[]) -} - - -#' @importFrom stats na.omit -#' @importFrom data.table `:=` as.data.table rbindlist 
%like% setattr copy -#' @rdname score -#' @export -score.forecast_quantile <- function(forecast, metrics = get_metrics(forecast), ...) { - forecast <- clean_forecast(forecast, copy = TRUE, na.omit = TRUE) - forecast_unit <- get_forecast_unit(forecast) - metrics <- validate_metrics(metrics) - forecast <- as.data.table(forecast) - - # transpose the forecasts that belong to the same forecast unit - # make sure the quantiles and predictions are ordered in the same way - f_transposed <- forecast[, .( - predicted = list(predicted[order(quantile_level)]), - observed = unique(observed), - quantile_level = list(sort(quantile_level, na.last = TRUE)), - scoringutils_quantile_level = toString(sort(quantile_level, na.last = TRUE)) - ), by = forecast_unit] - - # split according to quantile_level lengths and do calculations for different - # quantile_level lengths separately. The function `wis()` assumes that all - # forecasts have the same quantile_levels - f_split <- split(f_transposed, f_transposed$scoringutils_quantile_level) - - split_result <- lapply(f_split, function(forecast) { - # create a matrix out of the list of predicted values and quantile_levels - observed <- forecast$observed - predicted <- do.call(rbind, forecast$predicted) - quantile_level <- unlist(unique(forecast$quantile_level)) - forecast[, c( - "observed", "predicted", "quantile_level", "scoringutils_quantile_level" - ) := NULL] - - forecast <- apply_metrics( - forecast, metrics, - observed, predicted, quantile_level - ) - return(forecast) - }) - scores <- rbindlist(split_result, fill = TRUE) - - scores <- as_scores(scores, metrics = names(metrics)) - - return(scores[]) -} - #' @title Apply a list of functions to a data table of forecasts #' @description @@ -290,74 +135,105 @@ apply_metrics <- function(forecast, metrics, ...) 
{ } -#' Construct an object of class `scores` +#' @title Run a function safely #' @description -#' This function creates an object of class `scores` based on a -#' data.table or similar. -#' @param scores A data.table or similar with scores as produced by [score()]. -#' @param metrics A character vector with the names of the scores -#' (i.e. the names of the scoring rules used for scoring). -#' @param ... Additional arguments to [data.table::as.data.table()] +#' This is a wrapper/helper function designed to run a function safely +#' when it is not completely clear what arguments could be passed to the +#' function. +#' +#' All named arguments in `...` that are not accepted by `fun` are removed. +#' All unnamed arguments are passed on to the function. In case `fun` errors, +#' the error will be converted to a warning and `run_safely` returns `NULL`. +#' +#' `run_safely` can be useful when constructing functions to be used as +#' metrics in [score()]. +#' +#' @param ... Arguments to pass to `fun`. +#' @param fun A function to execute. +#' @param metric_name A character string with the name of the metric. Used to +#' provide a more informative warning message in case `fun` errors. +#' @importFrom cli cli_warn +#' @importFrom checkmate assert_function +#' @return The result of `fun` or `NULL` if `fun` errors #' @keywords internal -#' @importFrom data.table as.data.table setattr -#' @return An object of class `scores` #' @examples -#' \dontrun{ -#' df <- data.frame( -#' model = "A", -#' wis = "0.1" -#' ) -#' new_scores(df, "wis") -#' } -new_scores <- function(scores, metrics, ...) { - scores <- as.data.table(scores, ...) 
- class(scores) <- c("scores", class(scores)) - setattr(scores, "metrics", metrics) - return(scores[]) -} +#' f <- function(x) {x} +#' scoringutils:::run_safely(2, fun = f, metric_name = "f") +#' scoringutils:::run_safely(2, y = 3, fun = f, metric_name = "f") +#' scoringutils:::run_safely(fun = f, metric_name = "f") +#' scoringutils:::run_safely(y = 3, fun = f, metric_name = "f") +run_safely <- function(..., fun, metric_name) { + assert_function(fun) + args <- list(...) + # Check if the function accepts ... as an argument + if ("..." %in% names(formals(fun))) { + valid_args <- args + } else if (is.null(names(args))) { + # if no arguments are named, just pass all arguments on + valid_args <- args + } else { + # Identify the arguments that fun() accepts + possible_args <- names(formals(fun)) + # keep valid arguments as well as unnamed arguments + valid_args <- args[names(args) == "" | names(args) %in% possible_args] + } + result <- try(do.call(fun, valid_args), silent = TRUE) -#' Create an object of class `scores` from data -#' @description This convenience function wraps [new_scores()] and validates -#' the `scores` object. -#' @inherit new_scores params return -#' @importFrom checkmate assert_data_frame -#' @keywords internal -as_scores <- function(scores, metrics) { - assert_data_frame(scores) - present_metrics <- metrics[metrics %in% colnames(scores)] - scores <- new_scores(scores, present_metrics) - validate_scores(scores) - return(scores[]) + if (inherits(result, "try-error")) { + #nolint start: object_usage_linter + msg <- conditionMessage(attr(result, "condition")) + cli_warn( + c( + "!" = "Computation for {.var {metric_name}} failed. + Error: {msg}." + ) + ) + #nolint end + return(NULL) + } + return(result) } -#' Validate an object of class `scores` +#' @title Validate metrics +#' #' @description -#' This function validates an object of class `scores`, checking -#' that it has the correct class and that it has a `metrics` attribute. 
-#' @inheritParams new_scores -#' @returns Returns `NULL` invisibly -#' @importFrom checkmate assert_class assert_data_frame -#' @keywords internal -validate_scores <- function(scores) { - assert_data_frame(scores) - assert_class(scores, "scores") - # error if no metrics exists + - # throw warning if any of the metrics is not in the data - get_metrics.scores(scores, error = TRUE) - return(invisible(NULL)) -} - -#' @method `[` scores -#' @importFrom data.table setattr -#' @export -`[.scores` <- function(x, ...) { - ret <- NextMethod() - if (is.data.table(ret)) { - setattr(ret, "metrics", attr(x, "metrics")) - } else if (is.data.frame(ret)) { - attr(ret, "metrics") <- attr(x, "metrics") +#' This function validates whether the list of metrics is a list +#' of valid functions. +#' +#' The function is used in [score()] to make sure that all metrics are valid +#' functions. +#' +#' @param metrics A named list with metrics. Every element should be a scoring +#' function to be applied to the data. +#' @importFrom cli cli_warn +#' +#' @return +#' A named list of metrics, with those filtered out that are not +#' valid functions +#' @importFrom checkmate assert_list test_list check_function +#' @keywords internal_input_check +validate_metrics <- function(metrics) { + + assert_list(metrics, min.len = 1, names = "named") + + for (i in seq_along(metrics)) { + check_fun <- check_function(metrics[[i]]) + if (!isTRUE(check_fun)) { + #nolint start: keyword_quote_linter + cli_warn( + c( + "!" = "`Metrics` element number {i} is not a valid function." 
+ ) + ) + #nolint end + names(metrics)[i] <- "scoringutils_delete" + } } - return(ret) + metrics[names(metrics) == "scoringutils_delete"] <- NULL + + assert_list(metrics, min.len = 1, .var.name = "valid metrics") + + return(metrics) } diff --git a/R/theme-scoringutils.R b/R/theme-scoringutils.R new file mode 100644 index 000000000..d24a76484 --- /dev/null +++ b/R/theme-scoringutils.R @@ -0,0 +1,18 @@ +#' @title Scoringutils ggplot2 theme +#' +#' @description +#' A theme for ggplot2 plots used in `scoringutils`. +#' @return A ggplot2 theme +#' @importFrom ggplot2 theme theme_minimal element_line `%+replace%` +#' @keywords plotting +#' @export +theme_scoringutils <- function() { + theme_minimal() %+replace% + theme(axis.line = element_line(colour = "grey80"), + axis.ticks = element_line(colour = "grey80"), + panel.grid.major = element_blank(), + panel.grid.minor = element_blank(), + panel.border = element_blank(), + panel.background = element_blank(), + legend.position = "bottom") +} diff --git a/R/convenience-functions.R b/R/transform-forecasts.R similarity index 83% rename from R/convenience-functions.R rename to R/transform-forecasts.R index 3d8196f78..8a379cb53 100644 --- a/R/convenience-functions.R +++ b/R/transform-forecasts.R @@ -239,44 +239,3 @@ log_shift <- function(x, offset = 0, base = exp(1)) { } log(x + offset, base = base) } - - -#' @title Set unit of a single forecast manually -#' -#' @description -#' Helper function to set the unit of a single forecast (i.e. the -#' combination of columns that uniquely define a single forecast) manually. -#' This simple function keeps the columns specified in `forecast_unit` (plus -#' additional protected columns, e.g. for observed values, predictions or -#' quantile levels) and removes duplicate rows. `set_forecast_unit()` will -#' mainly be called when constructing a `forecast` object (see [as_forecast()]) -#' via the `forecast_unit` argument there. 
-#' -#' If not done explicitly, `scoringutils` attempts to determine the unit -#' of a single forecast automatically by simply assuming that all column names -#' are relevant to determine the forecast unit. This may lead to unexpected -#' behaviour, so setting the forecast unit explicitly can help make the code -#' easier to debug and easier to read. -#' -#' @inheritParams as_forecast -#' @param forecast_unit Character vector with the names of the columns that -#' uniquely identify a single forecast. -#' @importFrom cli cli_warn -#' @return A data.table with only those columns kept that are relevant to -#' scoring or denote the unit of a single forecast as specified by the user. -#' @importFrom data.table ':=' is.data.table copy -#' @importFrom checkmate assert_character assert_subset -#' @keywords as_forecast -#' @examples -#' library(magrittr) # pipe operator -#' example_quantile %>% -#' scoringutils:::set_forecast_unit( -#' c("location", "target_end_date", "target_type", "horizon", "model") -#' ) -set_forecast_unit <- function(data, forecast_unit) { - data <- ensure_data.table(data) - assert_subset(forecast_unit, names(data), empty.ok = FALSE) - keep_cols <- c(get_protected_columns(data), forecast_unit) - out <- unique(data[, .SD, .SDcols = keep_cols]) - return(out) -} diff --git a/R/utils.R b/R/utils.R deleted file mode 100644 index aec13251d..000000000 --- a/R/utils.R +++ /dev/null @@ -1,78 +0,0 @@ -#' @title Run a function safely -#' @description -#' This is a wrapper function designed to run a function safely -#' when it is not completely clear what arguments could be passed to the -#' function. -#' -#' All named arguments in `...` that are not accepted by `fun` are removed. -#' All unnamed arguments are passed on to the function. In case `fun` errors, -#' the error will be converted to a warning and `run_safely` returns `NULL`. -#' -#' `run_safely` can be useful when constructing functions to be used as -#' metrics in [score()]. -#' -#' @param ... 
Arguments to pass to `fun`. -#' @param fun A function to execute. -#' @param metric_name A character string with the name of the metric. Used to -#' provide a more informative warning message in case `fun` errors. -#' @importFrom cli cli_warn -#' @importFrom checkmate assert_function -#' @return The result of `fun` or `NULL` if `fun` errors -#' @keywords internal -#' @examples -#' f <- function(x) {x} -#' scoringutils:::run_safely(2, fun = f, metric_name = "f") -#' scoringutils:::run_safely(2, y = 3, fun = f, metric_name = "f") -#' scoringutils:::run_safely(fun = f, metric_name = "f") -#' scoringutils:::run_safely(y = 3, fun = f, metric_name = "f") -run_safely <- function(..., fun, metric_name) { - assert_function(fun) - args <- list(...) - # Check if the function accepts ... as an argument - if ("..." %in% names(formals(fun))) { - valid_args <- args - } else if (is.null(names(args))) { - # if no arguments are named, just pass all arguments on - valid_args <- args - } else { - # Identify the arguments that fun() accepts - possible_args <- names(formals(fun)) - # keep valid arguments as well as unnamed arguments - valid_args <- args[names(args) == "" | names(args) %in% possible_args] - } - - result <- try(do.call(fun, valid_args), silent = TRUE) - - if (inherits(result, "try-error")) { - #nolint start: object_usage_linter - msg <- conditionMessage(attr(result, "condition")) - cli_warn( - c( - "!" = "Computation for {.var {metric_name}} failed. - Error: {msg}." - ) - ) - #nolint end - return(NULL) - } - return(result) -} - - -#' Ensure that an object is a `data.table` -#' @description -#' This function ensures that an object is a `data table`. -#' If the object is not a data table, it is converted to one. If the object -#' is a data table, a copy of the object is returned. -#' @param data An object to ensure is a data table. -#' @return A data.table/a copy of an existing data.table. 
-#' @keywords internal -#' @importFrom data.table copy is.data.table as.data.table -ensure_data.table <- function(data) { - if (is.data.table(data)) { - data <- copy(data) - } else { - data <- as.data.table(data) - } - return(data) -} diff --git a/R/z_globalVariables.R b/R/z-globalVariables.R similarity index 100% rename from R/z_globalVariables.R rename to R/z-globalVariables.R diff --git a/man/as_forecast.Rd b/man/as_forecast.Rd index 42f1525ef..8e1152151 100644 --- a/man/as_forecast.Rd +++ b/man/as_forecast.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/forecast.R +% Please edit documentation in R/class-forecast.R \name{as_forecast} \alias{as_forecast} \title{General information on creating a \code{forecast} object} @@ -32,9 +32,12 @@ returned: } } \description{ -There are several ``as_forecast_\if{html}{\out{}}()\verb{functions to process and validate a data.frame (or similar) or similar with forecasts and observations. If the input passes all input checks, those functions will be converted to a}forecast` object. A forecast object is a `data.table` with a -class `forecast` and an additional class that depends on the forecast type. -Every forecast type has its own `as_forecast_\if{html}{\out{}}()` function. +There are several \verb{as_forecast_()} functions to process and validate +a data.frame (or similar) or similar with forecasts and observations. If +the input passes all input checks, those functions will be converted +to a \code{forecast} object. A forecast object is a \code{data.table} with a +class \code{forecast} and an additional class that depends on the forecast type. +Every forecast type has its own \verb{as_forecast_()} function. See the details section below for more information on the expected input formats. 
diff --git a/man/as_forecast_binary.Rd b/man/as_forecast_binary.Rd index 160b3ffcb..4e11a0e16 100644 --- a/man/as_forecast_binary.Rd +++ b/man/as_forecast_binary.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/forecast.R +% Please edit documentation in R/class-forecast-binary.R \name{as_forecast_binary} \alias{as_forecast_binary} \title{Create a \code{forecast} object for binary forecasts} diff --git a/man/as_forecast_generic.Rd b/man/as_forecast_generic.Rd index 3895e756e..744f07748 100644 --- a/man/as_forecast_generic.Rd +++ b/man/as_forecast_generic.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/forecast.R +% Please edit documentation in R/class-forecast.R \name{as_forecast_generic} \alias{as_forecast_generic} \title{Common functionality for \verb{as_forecast_} functions} diff --git a/man/as_forecast_nominal.Rd b/man/as_forecast_nominal.Rd index badc43b8d..fc240f600 100644 --- a/man/as_forecast_nominal.Rd +++ b/man/as_forecast_nominal.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/forecast.R +% Please edit documentation in R/class-forecast-nominal.R \name{as_forecast_nominal} \alias{as_forecast_nominal} \title{Create a \code{forecast} object for nominal forecasts} diff --git a/man/as_forecast_point.Rd b/man/as_forecast_point.Rd index 983ca7d89..7ab9fae9d 100644 --- a/man/as_forecast_point.Rd +++ b/man/as_forecast_point.Rd @@ -1,5 +1,6 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/forecast.R +% Please edit documentation in R/class-forecast-point.R, +% R/class-forecast-quantile.R \name{as_forecast_point} \alias{as_forecast_point} \alias{as_forecast_point.default} diff --git a/man/as_forecast_quantile.Rd b/man/as_forecast_quantile.Rd index 96ffe1c71..13f88a51d 100644 --- a/man/as_forecast_quantile.Rd +++ b/man/as_forecast_quantile.Rd @@ -1,5 +1,6 @@ % Generated by roxygen2: do not 
edit by hand -% Please edit documentation in R/forecast.R +% Please edit documentation in R/class-forecast-quantile.R, +% R/class-forecast-sample.R \name{as_forecast_quantile} \alias{as_forecast_quantile} \alias{as_forecast_quantile.default} diff --git a/man/as_forecast_sample.Rd b/man/as_forecast_sample.Rd index 78bc80952..f298061d0 100644 --- a/man/as_forecast_sample.Rd +++ b/man/as_forecast_sample.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/forecast.R +% Please edit documentation in R/class-forecast-sample.R \name{as_forecast_sample} \alias{as_forecast_sample} \title{Create a \code{forecast} object for sample-based forecasts} diff --git a/man/as_scores.Rd b/man/as_scores.Rd index 33b78b3b6..5c1e8589e 100644 --- a/man/as_scores.Rd +++ b/man/as_scores.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/score.R +% Please edit documentation in R/class-scores.R \name{as_scores} \alias{as_scores} \title{Create an object of class \code{scores} from data} diff --git a/man/assert_dims_ok_point.Rd b/man/assert_dims_ok_point.Rd index e49f12adf..7219c02cf 100644 --- a/man/assert_dims_ok_point.Rd +++ b/man/assert_dims_ok_point.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/check-inputs-scoring-functions.R +% Please edit documentation in R/metrics-point.R \name{assert_dims_ok_point} \alias{assert_dims_ok_point} \title{Assert Inputs Have Matching Dimensions} diff --git a/man/assert_forecast.Rd b/man/assert_forecast.Rd index 9b23f1225..0b73a05e0 100644 --- a/man/assert_forecast.Rd +++ b/man/assert_forecast.Rd @@ -1,18 +1,16 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/forecast.R -\name{assert_forecast} -\alias{assert_forecast} -\alias{assert_forecast.default} +% Please edit documentation in R/class-forecast-binary.R, +% R/class-forecast-point.R, R/class-forecast-quantile.R, +% 
R/class-forecast-sample.R, R/class-forecast.R +\name{assert_forecast.forecast_binary} \alias{assert_forecast.forecast_binary} \alias{assert_forecast.forecast_point} \alias{assert_forecast.forecast_quantile} \alias{assert_forecast.forecast_sample} +\alias{assert_forecast} +\alias{assert_forecast.default} \title{Assert that input is a forecast object and passes validations} \usage{ -assert_forecast(forecast, forecast_type = NULL, verbose = TRUE, ...) - -\method{assert_forecast}{default}(forecast, forecast_type = NULL, verbose = TRUE, ...) - \method{assert_forecast}{forecast_binary}(forecast, forecast_type = NULL, verbose = TRUE, ...) \method{assert_forecast}{forecast_point}(forecast, forecast_type = NULL, verbose = TRUE, ...) @@ -20,6 +18,10 @@ assert_forecast(forecast, forecast_type = NULL, verbose = TRUE, ...) \method{assert_forecast}{forecast_quantile}(forecast, forecast_type = NULL, verbose = TRUE, ...) \method{assert_forecast}{forecast_sample}(forecast, forecast_type = NULL, verbose = TRUE, ...) + +assert_forecast(forecast, forecast_type = NULL, verbose = TRUE, ...) + +\method{assert_forecast}{default}(forecast, forecast_type = NULL, verbose = TRUE, ...) 
} \arguments{ \item{forecast}{A forecast object (a validated data.table with predicted and diff --git a/man/assert_forecast_generic.Rd b/man/assert_forecast_generic.Rd index 89fa6aaae..94e78bfea 100644 --- a/man/assert_forecast_generic.Rd +++ b/man/assert_forecast_generic.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/forecast.R +% Please edit documentation in R/class-forecast.R \name{assert_forecast_generic} \alias{assert_forecast_generic} \title{Validation common to all forecast types} diff --git a/man/assert_forecast_type.Rd b/man/assert_forecast_type.Rd index 884fb07e2..7667303a2 100644 --- a/man/assert_forecast_type.Rd +++ b/man/assert_forecast_type.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/get_-functions.R +% Please edit documentation in R/get-forecast-type.R \name{assert_forecast_type} \alias{assert_forecast_type} \title{Assert that forecast type is as expected} diff --git a/man/assert_input_binary.Rd b/man/assert_input_binary.Rd index df1022f40..ddb01dfc0 100644 --- a/man/assert_input_binary.Rd +++ b/man/assert_input_binary.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/check-inputs-scoring-functions.R +% Please edit documentation in R/metrics-binary.R \name{assert_input_binary} \alias{assert_input_binary} \title{Assert that inputs are correct for binary forecast} diff --git a/man/assert_input_interval.Rd b/man/assert_input_interval.Rd index 7a965d824..a0271c470 100644 --- a/man/assert_input_interval.Rd +++ b/man/assert_input_interval.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/check-inputs-scoring-functions.R +% Please edit documentation in R/metrics-interval-range.R \name{assert_input_interval} \alias{assert_input_interval} \title{Assert that inputs are correct for interval-based forecast} diff --git a/man/assert_input_nominal.Rd 
b/man/assert_input_nominal.Rd index 539a797e2..4fbb6fa37 100644 --- a/man/assert_input_nominal.Rd +++ b/man/assert_input_nominal.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/check-inputs-scoring-functions.R +% Please edit documentation in R/metrics-nominal.R \name{assert_input_nominal} \alias{assert_input_nominal} \title{Assert that inputs are correct for nominal forecasts} diff --git a/man/assert_input_point.Rd b/man/assert_input_point.Rd index 3b2d63e2f..1ebf7799f 100644 --- a/man/assert_input_point.Rd +++ b/man/assert_input_point.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/check-inputs-scoring-functions.R +% Please edit documentation in R/metrics-point.R \name{assert_input_point} \alias{assert_input_point} \title{Assert that inputs are correct for point forecast} diff --git a/man/assert_input_quantile.Rd b/man/assert_input_quantile.Rd index 23631ed1f..d10761eac 100644 --- a/man/assert_input_quantile.Rd +++ b/man/assert_input_quantile.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/check-inputs-scoring-functions.R +% Please edit documentation in R/metrics-quantile.R \name{assert_input_quantile} \alias{assert_input_quantile} \title{Assert that inputs are correct for quantile-based forecast} diff --git a/man/assert_input_sample.Rd b/man/assert_input_sample.Rd index a7bd09d9c..f39932d23 100644 --- a/man/assert_input_sample.Rd +++ b/man/assert_input_sample.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/check-inputs-scoring-functions.R +% Please edit documentation in R/metrics-sample.R \name{assert_input_sample} \alias{assert_input_sample} \title{Assert that inputs are correct for sample-based forecast} diff --git a/man/validate_scores.Rd b/man/assert_scores.Rd similarity index 79% rename from man/validate_scores.Rd rename to man/assert_scores.Rd index 
e27da22b8..fb8aae377 100644 --- a/man/validate_scores.Rd +++ b/man/assert_scores.Rd @@ -1,10 +1,10 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/score.R -\name{validate_scores} -\alias{validate_scores} +% Please edit documentation in R/class-scores.R +\name{assert_scores} +\alias{assert_scores} \title{Validate an object of class \code{scores}} \usage{ -validate_scores(scores) +assert_scores(scores) } \arguments{ \item{scores}{A data.table or similar with scores as produced by \code{\link[=score]{score()}}.} diff --git a/man/check_dims_ok_point.Rd b/man/check_dims_ok_point.Rd index 5c96422f2..b2ad4e057 100644 --- a/man/check_dims_ok_point.Rd +++ b/man/check_dims_ok_point.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/check-inputs-scoring-functions.R +% Please edit documentation in R/metrics-point.R \name{check_dims_ok_point} \alias{check_dims_ok_point} \title{Check Inputs Have Matching Dimensions} diff --git a/man/check_duplicates.Rd b/man/check_duplicates.Rd index 7473f1e11..30c1d4c9d 100644 --- a/man/check_duplicates.Rd +++ b/man/check_duplicates.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/check-input-helpers.R +% Please edit documentation in R/get-duplicate-forecasts.R \name{check_duplicates} \alias{check_duplicates} \title{Check that there are no duplicate forecasts} diff --git a/man/check_input_binary.Rd b/man/check_input_binary.Rd index ceed81d1c..2a79630fe 100644 --- a/man/check_input_binary.Rd +++ b/man/check_input_binary.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/check-inputs-scoring-functions.R +% Please edit documentation in R/metrics-binary.R \name{check_input_binary} \alias{check_input_binary} \title{Check that inputs are correct for binary forecast} diff --git a/man/check_input_interval.Rd b/man/check_input_interval.Rd index 7d62f4385..f8734d415 100644 --- 
a/man/check_input_interval.Rd +++ b/man/check_input_interval.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/check-inputs-scoring-functions.R +% Please edit documentation in R/metrics-interval-range.R \name{check_input_interval} \alias{check_input_interval} \title{Check that inputs are correct for interval-based forecast} diff --git a/man/check_input_point.Rd b/man/check_input_point.Rd index 32a18affe..09b0359db 100644 --- a/man/check_input_point.Rd +++ b/man/check_input_point.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/check-inputs-scoring-functions.R +% Please edit documentation in R/metrics-point.R \name{check_input_point} \alias{check_input_point} \title{Check that inputs are correct for point forecast} diff --git a/man/check_input_quantile.Rd b/man/check_input_quantile.Rd index 70504bb38..74ead30ee 100644 --- a/man/check_input_quantile.Rd +++ b/man/check_input_quantile.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/check-inputs-scoring-functions.R +% Please edit documentation in R/metrics-quantile.R \name{check_input_quantile} \alias{check_input_quantile} \title{Check that inputs are correct for quantile-based forecast} diff --git a/man/check_input_sample.Rd b/man/check_input_sample.Rd index 5b0d48b3c..dbe2efebe 100644 --- a/man/check_input_sample.Rd +++ b/man/check_input_sample.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/check-inputs-scoring-functions.R +% Please edit documentation in R/metrics-sample.R \name{check_input_sample} \alias{check_input_sample} \title{Check that inputs are correct for sample-based forecast} diff --git a/man/check_number_per_forecast.Rd b/man/check_number_per_forecast.Rd index 39901356c..fd7784aaf 100644 --- a/man/check_number_per_forecast.Rd +++ b/man/check_number_per_forecast.Rd @@ -1,8 +1,8 @@ % Generated by roxygen2: do not 
edit by hand -% Please edit documentation in R/check-input-helpers.R +% Please edit documentation in R/class-forecast.R \name{check_number_per_forecast} \alias{check_number_per_forecast} -\title{Check that all forecasts have the same number of quantiles or samples} +\title{Check that all forecasts have the same number of rows} \usage{ check_number_per_forecast(data, forecast_unit) } @@ -16,7 +16,8 @@ Returns TRUE if the check was successful and a string with an error message otherwise. } \description{ -Function checks the number of quantiles or samples per forecast. +Helper function that checks the number of rows (corresponding e.g to +quantiles or samples) per forecast. If the number of quantiles or samples is the same for all forecasts, it returns TRUE and a string with an error message otherwise. } diff --git a/man/clean_forecast.Rd b/man/clean_forecast.Rd index b5158ddde..431f72a52 100644 --- a/man/clean_forecast.Rd +++ b/man/clean_forecast.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/forecast.R +% Please edit documentation in R/class-forecast.R \name{clean_forecast} \alias{clean_forecast} \title{Clean forecast object} diff --git a/man/ensure_data.table.Rd b/man/ensure_data.table.Rd index 6df303271..c6826dede 100644 --- a/man/ensure_data.table.Rd +++ b/man/ensure_data.table.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/utils.R +% Please edit documentation in R/check-input-helpers.R \name{ensure_data.table} \alias{ensure_data.table} \title{Ensure that an object is a \code{data.table}} diff --git a/man/example_binary.Rd b/man/example_binary.Rd index 6c0ede99d..6d2a28ed7 100644 --- a/man/example_binary.Rd +++ b/man/example_binary.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/data.R +% Please edit documentation in R/class-forecast-binary.R \docType{data} \name{example_binary} \alias{example_binary} diff --git 
a/man/example_nominal.Rd b/man/example_nominal.Rd index 8bfe131c9..e07914840 100644 --- a/man/example_nominal.Rd +++ b/man/example_nominal.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/data.R +% Please edit documentation in R/class-forecast-nominal.R \docType{data} \name{example_nominal} \alias{example_nominal} diff --git a/man/example_point.Rd b/man/example_point.Rd index 3a3542b7b..a8872f0f8 100644 --- a/man/example_point.Rd +++ b/man/example_point.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/data.R +% Please edit documentation in R/class-forecast-point.R \docType{data} \name{example_point} \alias{example_point} diff --git a/man/example_quantile.Rd b/man/example_quantile.Rd index 6456ba744..6bbe1df3b 100644 --- a/man/example_quantile.Rd +++ b/man/example_quantile.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/data.R +% Please edit documentation in R/class-forecast-quantile.R \docType{data} \name{example_quantile} \alias{example_quantile} diff --git a/man/example_sample_continuous.Rd b/man/example_sample_continuous.Rd index 01c7a1f2c..1e9a72273 100644 --- a/man/example_sample_continuous.Rd +++ b/man/example_sample_continuous.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/data.R +% Please edit documentation in R/class-forecast-sample.R \docType{data} \name{example_sample_continuous} \alias{example_sample_continuous} diff --git a/man/example_sample_discrete.Rd b/man/example_sample_discrete.Rd index 46e0be40d..75569fca4 100644 --- a/man/example_sample_discrete.Rd +++ b/man/example_sample_discrete.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/data.R +% Please edit documentation in R/class-forecast-sample.R \docType{data} \name{example_sample_discrete} \alias{example_sample_discrete} diff --git 
a/man/get_correlations.Rd b/man/get_correlations.Rd index ab6a08116..d83a3ca7a 100644 --- a/man/get_correlations.Rd +++ b/man/get_correlations.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/correlations.R +% Please edit documentation in R/get-correlations.R \name{get_correlations} \alias{get_correlations} \title{Calculate correlation between metrics} diff --git a/man/get_coverage.Rd b/man/get_coverage.Rd index af2301bde..0dcdb036f 100644 --- a/man/get_coverage.Rd +++ b/man/get_coverage.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/get_-functions.R +% Please edit documentation in R/get-coverage.R \name{get_coverage} \alias{get_coverage} \title{Get quantile and interval coverage values for quantile-based forecasts} diff --git a/man/get_duplicate_forecasts.Rd b/man/get_duplicate_forecasts.Rd index 298d0bd5f..b69f67fe5 100644 --- a/man/get_duplicate_forecasts.Rd +++ b/man/get_duplicate_forecasts.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/get_-functions.R +% Please edit documentation in R/get-duplicate-forecasts.R \name{get_duplicate_forecasts} \alias{get_duplicate_forecasts} \title{Find duplicate forecasts} diff --git a/man/get_forecast_counts.Rd b/man/get_forecast_counts.Rd index 5659b8d8c..6a0cc74cd 100644 --- a/man/get_forecast_counts.Rd +++ b/man/get_forecast_counts.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/get_-functions.R +% Please edit documentation in R/get-forecast-counts.R \name{get_forecast_counts} \alias{get_forecast_counts} \title{Count number of available forecasts} diff --git a/man/get_forecast_type.Rd b/man/get_forecast_type.Rd index 57f40d981..43f2e05c2 100644 --- a/man/get_forecast_type.Rd +++ b/man/get_forecast_type.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/get_-functions.R +% Please 
edit documentation in R/get-forecast-type.R \name{get_forecast_type} \alias{get_forecast_type} \title{Get forecast type from forecast object} diff --git a/man/get_forecast_unit.Rd b/man/get_forecast_unit.Rd index d3552a08b..cf2348e8c 100644 --- a/man/get_forecast_unit.Rd +++ b/man/get_forecast_unit.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/get_-functions.R +% Please edit documentation in R/forecast-unit.R \name{get_forecast_unit} \alias{get_forecast_unit} \title{Get unit of a single forecast} diff --git a/man/get_metrics.Rd b/man/get_metrics.Rd index 5f6dae9b8..14cf188dc 100644 --- a/man/get_metrics.Rd +++ b/man/get_metrics.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/default-scoring-rules.R +% Please edit documentation in R/metrics.R \name{get_metrics} \alias{get_metrics} \title{Get metrics} diff --git a/man/get_metrics.forecast_binary.Rd b/man/get_metrics.forecast_binary.Rd index d345be742..ea7440fb4 100644 --- a/man/get_metrics.forecast_binary.Rd +++ b/man/get_metrics.forecast_binary.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/default-scoring-rules.R +% Please edit documentation in R/class-forecast-binary.R \name{get_metrics.forecast_binary} \alias{get_metrics.forecast_binary} \title{Get default metrics for binary forecasts} diff --git a/man/get_metrics.forecast_nominal.Rd b/man/get_metrics.forecast_nominal.Rd index 06f225862..8d6f556b6 100644 --- a/man/get_metrics.forecast_nominal.Rd +++ b/man/get_metrics.forecast_nominal.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/default-scoring-rules.R +% Please edit documentation in R/class-forecast-nominal.R \name{get_metrics.forecast_nominal} \alias{get_metrics.forecast_nominal} \title{Get default metrics for nominal forecasts} diff --git a/man/get_metrics.forecast_point.Rd b/man/get_metrics.forecast_point.Rd 
index 43f8143e5..e47d56118 100644 --- a/man/get_metrics.forecast_point.Rd +++ b/man/get_metrics.forecast_point.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/default-scoring-rules.R +% Please edit documentation in R/class-forecast-point.R \name{get_metrics.forecast_point} \alias{get_metrics.forecast_point} \title{Get default metrics for point forecasts} diff --git a/man/get_metrics.forecast_quantile.Rd b/man/get_metrics.forecast_quantile.Rd index 9740cb44d..fc3f248d4 100644 --- a/man/get_metrics.forecast_quantile.Rd +++ b/man/get_metrics.forecast_quantile.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/default-scoring-rules.R +% Please edit documentation in R/class-forecast-quantile.R \name{get_metrics.forecast_quantile} \alias{get_metrics.forecast_quantile} \title{Get default metrics for quantile-based forecasts} diff --git a/man/get_metrics.forecast_sample.Rd b/man/get_metrics.forecast_sample.Rd index 10e7d8c67..e51b65e41 100644 --- a/man/get_metrics.forecast_sample.Rd +++ b/man/get_metrics.forecast_sample.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/default-scoring-rules.R +% Please edit documentation in R/class-forecast-sample.R \name{get_metrics.forecast_sample} \alias{get_metrics.forecast_sample} \title{Get default metrics for sample-based forecasts} diff --git a/man/get_metrics.scores.Rd b/man/get_metrics.scores.Rd index d658ddb07..944371f01 100644 --- a/man/get_metrics.scores.Rd +++ b/man/get_metrics.scores.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/get_-functions.R +% Please edit documentation in R/class-scores.R \name{get_metrics.scores} \alias{get_metrics.scores} \title{Get names of the metrics that were used for scoring} diff --git a/man/get_pit.Rd b/man/get_pit.Rd index 0da5d9b84..bba4a675d 100644 --- a/man/get_pit.Rd +++ b/man/get_pit.Rd @@ -1,19 
+1,20 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/pit.R -\name{get_pit} +% Please edit documentation in R/class-forecast-quantile.R, +% R/class-forecast-sample.R, R/get-pit.R +\name{get_pit.forecast_quantile} +\alias{get_pit.forecast_quantile} +\alias{get_pit.forecast_sample} \alias{get_pit} \alias{get_pit.default} -\alias{get_pit.forecast_sample} -\alias{get_pit.forecast_quantile} \title{Probability integral transformation (data.frame version)} \usage{ -get_pit(forecast, by, ...) - -\method{get_pit}{default}(forecast, by, ...) +\method{get_pit}{forecast_quantile}(forecast, by, ...) \method{get_pit}{forecast_sample}(forecast, by, n_replicates = 100, ...) -\method{get_pit}{forecast_quantile}(forecast, by, ...) +get_pit(forecast, by, ...) + +\method{get_pit}{default}(forecast, by, ...) } \arguments{ \item{forecast}{A forecast object (a validated data.table with predicted and diff --git a/man/get_protected_columns.Rd b/man/get_protected_columns.Rd index 1346b33bc..2abf59eeb 100644 --- a/man/get_protected_columns.Rd +++ b/man/get_protected_columns.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/get_-functions.R +% Please edit documentation in R/get-protected-columns.R \name{get_protected_columns} \alias{get_protected_columns} \title{Get protected columns from data} diff --git a/man/get_range_from_quantile.Rd b/man/get_range_from_quantile.Rd index 79552fb6d..08df972b9 100644 --- a/man/get_range_from_quantile.Rd +++ b/man/get_range_from_quantile.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/utils_data_handling.R +% Please edit documentation in R/helper-quantile-interval-range.R \name{get_range_from_quantile} \alias{get_range_from_quantile} \title{Get interval range belonging to a quantile} diff --git a/man/get_type.Rd b/man/get_type.Rd index e34786e5f..48cb90eb4 100644 --- a/man/get_type.Rd +++ b/man/get_type.Rd @@ -1,5 +1,5 @@ % 
Generated by roxygen2: do not edit by hand -% Please edit documentation in R/get_-functions.R +% Please edit documentation in R/get-forecast-type.R \name{get_type} \alias{get_type} \title{Get type of a vector or matrix of observed values or predictions} diff --git a/man/interval_score.Rd b/man/interval_score.Rd index 028458674..1e677ac8a 100644 --- a/man/interval_score.Rd +++ b/man/interval_score.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/metrics-range.R +% Please edit documentation in R/metrics-interval-range.R \name{interval_score} \alias{interval_score} \title{Interval score} diff --git a/man/is_forecast.Rd b/man/is_forecast.Rd index ec1b849f5..a52d396bd 100644 --- a/man/is_forecast.Rd +++ b/man/is_forecast.Rd @@ -1,25 +1,27 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/forecast.R -\name{is_forecast} -\alias{is_forecast} -\alias{is_forecast_sample} +% Please edit documentation in R/class-forecast-binary.R, +% R/class-forecast-nominal.R, R/class-forecast-point.R, +% R/class-forecast-quantile.R, R/class-forecast-sample.R, R/class-forecast.R +\name{is_forecast_binary} \alias{is_forecast_binary} +\alias{is_forecast_nominal} \alias{is_forecast_point} \alias{is_forecast_quantile} -\alias{is_forecast_nominal} +\alias{is_forecast_sample} +\alias{is_forecast} \title{Test whether an object is a forecast object} \usage{ -is_forecast(x) - -is_forecast_sample(x) - is_forecast_binary(x) +is_forecast_nominal(x) + is_forecast_point(x) is_forecast_quantile(x) -is_forecast_nominal(x) +is_forecast_sample(x) + +is_forecast(x) } \arguments{ \item{x}{An R object.} diff --git a/man/log_shift.Rd b/man/log_shift.Rd index dc2136b10..e852791b8 100644 --- a/man/log_shift.Rd +++ b/man/log_shift.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/convenience-functions.R +% Please edit documentation in R/transform-forecasts.R \name{log_shift} 
\alias{log_shift} \title{Log transformation with an additive shift} diff --git a/man/new_forecast.Rd b/man/new_forecast.Rd index d22e48fff..b07c5f15a 100644 --- a/man/new_forecast.Rd +++ b/man/new_forecast.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/forecast.R +% Please edit documentation in R/class-forecast.R \name{new_forecast} \alias{new_forecast} \title{Class constructor for \code{forecast} objects} diff --git a/man/new_scores.Rd b/man/new_scores.Rd index bdf4ca31d..29cf2af7a 100644 --- a/man/new_scores.Rd +++ b/man/new_scores.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/score.R +% Please edit documentation in R/class-scores.R \name{new_scores} \alias{new_scores} \title{Construct an object of class \code{scores}} diff --git a/man/pit_sample.Rd b/man/pit_sample.Rd index c9101bb93..a9e7effff 100644 --- a/man/pit_sample.Rd +++ b/man/pit_sample.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/pit.R +% Please edit documentation in R/metrics-sample.R \name{pit_sample} \alias{pit_sample} \title{Probability integral transformation for counts} diff --git a/man/plot_correlations.Rd b/man/plot_correlations.Rd index 8cf7acd67..1297e4d78 100644 --- a/man/plot_correlations.Rd +++ b/man/plot_correlations.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/plot.R +% Please edit documentation in R/get-correlations.R \name{plot_correlations} \alias{plot_correlations} \title{Plot correlation between metrics} diff --git a/man/plot_forecast_counts.Rd b/man/plot_forecast_counts.Rd index 1f02df602..fb9b891fc 100644 --- a/man/plot_forecast_counts.Rd +++ b/man/plot_forecast_counts.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/plot.R +% Please edit documentation in R/get-forecast-counts.R \name{plot_forecast_counts} 
\alias{plot_forecast_counts} \title{Visualise the number of available forecasts} diff --git a/man/plot_heatmap.Rd b/man/plot_heatmap.Rd index 335b01152..32a7bb643 100644 --- a/man/plot_heatmap.Rd +++ b/man/plot_heatmap.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/plot.R +% Please edit documentation in R/plot-heatmap.R \name{plot_heatmap} \alias{plot_heatmap} \title{Create a heatmap of a scoring metric} diff --git a/man/plot_interval_coverage.Rd b/man/plot_interval_coverage.Rd index 5514f357c..384b9559a 100644 --- a/man/plot_interval_coverage.Rd +++ b/man/plot_interval_coverage.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/plot.R +% Please edit documentation in R/get-coverage.R \name{plot_interval_coverage} \alias{plot_interval_coverage} \title{Plot interval coverage} diff --git a/man/plot_pairwise_comparisons.Rd b/man/plot_pairwise_comparisons.Rd index fd175c424..35b99f9b8 100644 --- a/man/plot_pairwise_comparisons.Rd +++ b/man/plot_pairwise_comparisons.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/plot.R +% Please edit documentation in R/pairwise-comparisons.R \name{plot_pairwise_comparisons} \alias{plot_pairwise_comparisons} \title{Plot heatmap of pairwise comparisons} diff --git a/man/plot_pit.Rd b/man/plot_pit.Rd index c776fb485..f439e68bc 100644 --- a/man/plot_pit.Rd +++ b/man/plot_pit.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/plot.R +% Please edit documentation in R/get-pit.R \name{plot_pit} \alias{plot_pit} \title{PIT histogram} diff --git a/man/plot_quantile_coverage.Rd b/man/plot_quantile_coverage.Rd index 5caf1ea18..25e15911b 100644 --- a/man/plot_quantile_coverage.Rd +++ b/man/plot_quantile_coverage.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/plot.R +% Please edit documentation in 
R/get-coverage.R \name{plot_quantile_coverage} \alias{plot_quantile_coverage} \title{Plot quantile coverage} diff --git a/man/plot_wis.Rd b/man/plot_wis.Rd index 9b3fae71e..fa33a808d 100644 --- a/man/plot_wis.Rd +++ b/man/plot_wis.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/plot.R +% Please edit documentation in R/plot-wis.R \name{plot_wis} \alias{plot_wis} \title{Plot contributions to the weighted interval score} diff --git a/man/print.forecast.Rd b/man/print.forecast.Rd index 96d219b11..f5c2c8f78 100644 --- a/man/print.forecast.Rd +++ b/man/print.forecast.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/print.R +% Please edit documentation in R/class-forecast.R \name{print.forecast} \alias{print.forecast} \title{Print information about a forecast object} diff --git a/man/quantile_to_interval.Rd b/man/quantile_to_interval.Rd index 0ee136622..168352dfb 100644 --- a/man/quantile_to_interval.Rd +++ b/man/quantile_to_interval.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/utils_data_handling.R +% Please edit documentation in R/helper-quantile-interval-range.R \name{quantile_to_interval} \alias{quantile_to_interval} \alias{quantile_to_interval_dataframe} diff --git a/man/run_safely.Rd b/man/run_safely.Rd index 525dc2fdd..5a2cdabc6 100644 --- a/man/run_safely.Rd +++ b/man/run_safely.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/utils.R +% Please edit documentation in R/score.R \name{run_safely} \alias{run_safely} \title{Run a function safely} @@ -18,7 +18,7 @@ provide a more informative warning message in case \code{fun} errors.} The result of \code{fun} or \code{NULL} if \code{fun} errors } \description{ -This is a wrapper function designed to run a function safely +This is a wrapper/helper function designed to run a function safely when it is not completely clear 
what arguments could be passed to the function. diff --git a/man/sample_to_interval_long.Rd b/man/sample_to_interval_long.Rd index c16254a07..58f903d90 100644 --- a/man/sample_to_interval_long.Rd +++ b/man/sample_to_interval_long.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/utils_data_handling.R +% Please edit documentation in R/helper-quantile-interval-range.R \name{sample_to_interval_long} \alias{sample_to_interval_long} \title{Change data from a sample-based format to a long interval range format} diff --git a/man/score.Rd b/man/score.Rd index a94fee252..a3d1b5af7 100644 --- a/man/score.Rd +++ b/man/score.Rd @@ -1,25 +1,27 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/score.R -\name{score} -\alias{score} +% Please edit documentation in R/class-forecast-binary.R, +% R/class-forecast-nominal.R, R/class-forecast-point.R, +% R/class-forecast-quantile.R, R/class-forecast-sample.R, R/score.R +\name{score.forecast_binary} \alias{score.forecast_binary} \alias{score.forecast_nominal} \alias{score.forecast_point} -\alias{score.forecast_sample} \alias{score.forecast_quantile} +\alias{score.forecast_sample} +\alias{score} \title{Evaluate forecasts} \usage{ -score(forecast, metrics, ...) - \method{score}{forecast_binary}(forecast, metrics = get_metrics(forecast), ...) \method{score}{forecast_nominal}(forecast, metrics = get_metrics(forecast), ...) \method{score}{forecast_point}(forecast, metrics = get_metrics(forecast), ...) +\method{score}{forecast_quantile}(forecast, metrics = get_metrics(forecast), ...) + \method{score}{forecast_sample}(forecast, metrics = get_metrics(forecast), ...) -\method{score}{forecast_quantile}(forecast, metrics = get_metrics(forecast), ...) +score(forecast, metrics, ...) 
} \arguments{ \item{forecast}{A forecast object (a validated data.table with predicted and diff --git a/man/select_metrics.Rd b/man/select_metrics.Rd index 813114b44..ffa127e18 100644 --- a/man/select_metrics.Rd +++ b/man/select_metrics.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/default-scoring-rules.R +% Please edit documentation in R/metrics.R \name{select_metrics} \alias{select_metrics} \title{Select metrics from a list of functions} diff --git a/man/set_forecast_unit.Rd b/man/set_forecast_unit.Rd index d609442ea..7bc42e0c3 100644 --- a/man/set_forecast_unit.Rd +++ b/man/set_forecast_unit.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/convenience-functions.R +% Please edit documentation in R/forecast-unit.R \name{set_forecast_unit} \alias{set_forecast_unit} \title{Set unit of a single forecast manually} diff --git a/man/theme_scoringutils.Rd b/man/theme_scoringutils.Rd index a59c506c5..95b8df7aa 100644 --- a/man/theme_scoringutils.Rd +++ b/man/theme_scoringutils.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/plot.R +% Please edit documentation in R/theme-scoringutils.R \name{theme_scoringutils} \alias{theme_scoringutils} \title{Scoringutils ggplot2 theme} diff --git a/man/transform_forecasts.Rd b/man/transform_forecasts.Rd index e58707f12..fddd040f2 100644 --- a/man/transform_forecasts.Rd +++ b/man/transform_forecasts.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand -% Please edit documentation in R/convenience-functions.R +% Please edit documentation in R/transform-forecasts.R \name{transform_forecasts} \alias{transform_forecasts} \title{Transform forecasts and observed values} diff --git a/man/validate_metrics.Rd b/man/validate_metrics.Rd index 27d02645f..1e2ba4c96 100644 --- a/man/validate_metrics.Rd +++ b/man/validate_metrics.Rd @@ -1,5 +1,5 @@ % Generated by roxygen2: do not edit by hand 
-% Please edit documentation in R/metrics-validate.R +% Please edit documentation in R/score.R \name{validate_metrics} \alias{validate_metrics} \title{Validate metrics} diff --git a/tests/testthat/_snaps/print.md b/tests/testthat/_snaps/class-forecast.md similarity index 100% rename from tests/testthat/_snaps/print.md rename to tests/testthat/_snaps/class-forecast.md diff --git a/tests/testthat/_snaps/plot_correlation/plot-correlation.svg b/tests/testthat/_snaps/get-correlations/plot-correlation.svg similarity index 100% rename from tests/testthat/_snaps/plot_correlation/plot-correlation.svg rename to tests/testthat/_snaps/get-correlations/plot-correlation.svg diff --git a/tests/testthat/_snaps/plot_interval_coverage/plot-interval-coverage.svg b/tests/testthat/_snaps/get-coverage/plot-interval-coverage.svg similarity index 100% rename from tests/testthat/_snaps/plot_interval_coverage/plot-interval-coverage.svg rename to tests/testthat/_snaps/get-coverage/plot-interval-coverage.svg diff --git a/tests/testthat/_snaps/plot_quantile_coverage/plot-quantile-coverage.svg b/tests/testthat/_snaps/get-coverage/plot-quantile-coverage.svg similarity index 100% rename from tests/testthat/_snaps/plot_quantile_coverage/plot-quantile-coverage.svg rename to tests/testthat/_snaps/get-coverage/plot-quantile-coverage.svg diff --git a/tests/testthat/_snaps/plot_avail_forecasts/plot-available-forecasts.svg b/tests/testthat/_snaps/get-forecast-counts/plot-available-forecasts.svg similarity index 100% rename from tests/testthat/_snaps/plot_avail_forecasts/plot-available-forecasts.svg rename to tests/testthat/_snaps/get-forecast-counts/plot-available-forecasts.svg diff --git a/tests/testthat/_snaps/plot_pit/plot-pit-integer.svg b/tests/testthat/_snaps/get-pit/plot-pit-integer.svg similarity index 100% rename from tests/testthat/_snaps/plot_pit/plot-pit-integer.svg rename to tests/testthat/_snaps/get-pit/plot-pit-integer.svg diff --git 
a/tests/testthat/_snaps/plot_pit/plot-pit-quantile-2.svg b/tests/testthat/_snaps/get-pit/plot-pit-quantile-2.svg similarity index 100% rename from tests/testthat/_snaps/plot_pit/plot-pit-quantile-2.svg rename to tests/testthat/_snaps/get-pit/plot-pit-quantile-2.svg diff --git a/tests/testthat/_snaps/plot_pit/plot-pit-quantile.svg b/tests/testthat/_snaps/get-pit/plot-pit-quantile.svg similarity index 100% rename from tests/testthat/_snaps/plot_pit/plot-pit-quantile.svg rename to tests/testthat/_snaps/get-pit/plot-pit-quantile.svg diff --git a/tests/testthat/_snaps/plot_pit/plot-pit-sample.svg b/tests/testthat/_snaps/get-pit/plot-pit-sample.svg similarity index 100% rename from tests/testthat/_snaps/plot_pit/plot-pit-sample.svg rename to tests/testthat/_snaps/get-pit/plot-pit-sample.svg diff --git a/tests/testthat/_snaps/utils_data_handling.md b/tests/testthat/_snaps/helper-quantile-interval-range.md similarity index 100% rename from tests/testthat/_snaps/utils_data_handling.md rename to tests/testthat/_snaps/helper-quantile-interval-range.md diff --git a/tests/testthat/_snaps/plot_pairwise_comparison/plot-pairwise-comparison-pval.svg b/tests/testthat/_snaps/pairwise_comparison/plot-pairwise-comparison-pval.svg similarity index 100% rename from tests/testthat/_snaps/plot_pairwise_comparison/plot-pairwise-comparison-pval.svg rename to tests/testthat/_snaps/pairwise_comparison/plot-pairwise-comparison-pval.svg diff --git a/tests/testthat/_snaps/plot_pairwise_comparison/plot-pairwise-comparison.svg b/tests/testthat/_snaps/pairwise_comparison/plot-pairwise-comparison.svg similarity index 100% rename from tests/testthat/_snaps/plot_pairwise_comparison/plot-pairwise-comparison.svg rename to tests/testthat/_snaps/pairwise_comparison/plot-pairwise-comparison.svg diff --git a/tests/testthat/test-check-input-helpers.R b/tests/testthat/test-check-input-helpers.R new file mode 100644 index 000000000..e8ebcb995 --- /dev/null +++ b/tests/testthat/test-check-input-helpers.R @@ -0,0 
+1,56 @@ +test_that("Check equal length works if all arguments have length 1", { + out <- interval_score( + observed = 5, + lower = 4, + upper = 6, + interval_range = 95, + weigh = TRUE, + separate_results = FALSE + ) + expect_equal(out, 0.05) +}) + + +test_that("check_columns_present works", { + expect_identical( + capture.output( + check_columns_present(example_binary, c("loc1", "loc2", "loc3")) + ), + paste( + "[1] \"Columns 'loc1', 'loc2', 'loc3' not found in data\"" + ) + ) + expect_identical( + capture.output( + check_columns_present(example_binary, c("loc1")) + ), + paste( + "[1] \"Column 'loc1' not found in data\"" + ) + ) + expect_true( + check_columns_present(example_binary, c("location_name")) + ) + expect_true( + check_columns_present(example_binary, columns = NULL) + ) +}) + +test_that("test_columns_not_present works", { + expect_true( + test_columns_not_present(example_binary, "sample_id") + ) + expect_false( + test_columns_not_present(example_binary, "location") + ) +}) + +test_that("check_columns_present() works", { + expect_equal( + check_columns_present(example_quantile, c("observed", "predicted", "nop")), + "Column 'nop' not found in data" + ) + expect_true( + check_columns_present(example_quantile, c("observed", "predicted")) + ) +}) diff --git a/tests/testthat/test-class-forecast-binary.R b/tests/testthat/test-class-forecast-binary.R new file mode 100644 index 000000000..82a29793b --- /dev/null +++ b/tests/testthat/test-class-forecast-binary.R @@ -0,0 +1,128 @@ +# ============================================================================== +# as_forecast_binary() +# ============================================================================== +test_that("output of as_forecast_binary() is accepted as input to score()", { + check <- suppressMessages(as_forecast_binary(example_binary)) + expect_no_error( + score_check <- score(na.omit(check)) + ) + expect_equal(score_check, suppressMessages(score(as_forecast_binary(example_binary)))) +}) + + +# 
============================================================================== +# is_forecast_binary() +# ============================================================================== +test_that("is_forecast_binary() works as expected", { + expect_true(is_forecast_binary(example_binary)) + expect_false(is_forecast_binary(example_point)) + expect_false(is_forecast_binary(example_quantile)) + expect_false(is_forecast_binary(example_sample_continuous)) + expect_false(is_forecast_binary(example_nominal)) +}) + + +# ============================================================================== +# assert_forecast.forecast_binary() +# ============================================================================== +test_that("assert_forecast.forecast_binary works as expected", { + test <- na.omit(as.data.table(example_binary)) + test[, "sample_id" := 1:nrow(test)] + + # error if there is a superfluous sample_id column + expect_error( + as_forecast_binary(test), + "Input looks like a binary forecast, but an additional column called `sample_id` or `quantile` was found." 
+ ) + + # expect error if probabilities are not in [0, 1] + test <- na.omit(as.data.table(example_binary)) + test[, "predicted" := predicted + 1] + expect_error( + as_forecast_binary(test), + "Input looks like a binary forecast, but found the following issue" + ) +}) + + + +# ============================================================================== +# score.forecast_binary() +# ============================================================================== +test_that("function produces output for a binary case", { + + expect_equal( + names(scores_binary), + c(get_forecast_unit(example_binary), names(get_metrics(example_binary))) + ) + + eval <- summarise_scores(scores_binary, by = c("model", "target_type")) + + expect_equal( + nrow(eval) > 1, + TRUE + ) + expect_equal( + colnames(eval), + c( + "model", "target_type", + "brier_score", + "log_score" + ) + ) + + expect_true("brier_score" %in% names(eval)) + + expect_s3_class(eval, c("scores", "data.table", "data.frame"), exact = TRUE) +}) + +test_that("score.forecast_binary() errors with only NA values", { + # `[.forecast()` will warn even before score() + only_nas <- suppressWarnings( + copy(example_binary)[, predicted := NA_real_] + ) + expect_error( + score(only_nas), + "After removing rows with NA values in the data, no forecasts are left." + ) +}) + +test_that("score() gives same result for binary as regular function", { + manual_eval <- brier_score( + factor(example_binary$observed), + example_binary$predicted + ) + expect_equal(scores_binary$brier_score, manual_eval[!is.na(manual_eval)]) +}) + +test_that( + "passing additional functions to score binary works handles them", { + test_fun <- function(x, y, ...) 
{ + if (hasArg("test")) { + message("test argument found") + } + return(y) + } + + df <- example_binary[model == "EuroCOVIDhub-ensemble" & + target_type == "Cases" & location == "DE"] %>% + as_forecast_binary() + + # passing a simple function works + expect_equal( + score(df, + metrics = list("identity" = function(x, y) {return(y)}))$identity, + df$predicted + ) + } +) + +# ============================================================================== +# get_metrics.forecast_binary() +# ============================================================================== + +test_that("get_metrics.forecast_binary() works as expected", { + expect_true( + is.list(get_metrics(example_binary)) + ) +}) diff --git a/tests/testthat/test-class-forecast-nominal.R b/tests/testthat/test-class-forecast-nominal.R new file mode 100644 index 000000000..f6c48fccc --- /dev/null +++ b/tests/testthat/test-class-forecast-nominal.R @@ -0,0 +1,55 @@ +# ============================================================================== +# as_forecast_nominal() +# ============================================================================== +test_that("as_forecast.forecast_nominal() works as expected", { + ex <- data.table::copy(example_nominal) %>% + na.omit() + + + expect_s3_class( + as_forecast_nominal(ex), + c("forecast_nominal", "forecast", "data.table", "data.frame"), + exact = TRUE + ) + + setnames(ex, old = "predicted_label", new = "label") + expect_no_condition( + as_forecast_nominal(ex, predicted_label = "label") + ) +}) + +test_that("as_forecast.forecast_nominal() breaks when rows with zero probability are missing", { + ex_faulty <- as.data.table(example_nominal) + ex_faulty <- ex_faulty[predicted != 0] + expect_warning( + expect_error( + as_forecast_nominal(ex_faulty), + "Found incomplete forecasts" + ), + "Some forecasts have different numbers of rows" + ) +}) + + +# ============================================================================== +# is_forecast_nominal() +# 
============================================================================== +test_that("is_forecast_nominal() works as expected", { + expect_true(is_forecast_nominal(example_nominal)) + expect_false(is_forecast_nominal(example_binary)) + expect_false(is_forecast_nominal(example_point)) + expect_false(is_forecast_nominal(example_quantile)) + expect_false(is_forecast_nominal(example_sample_continuous)) + expect_false(is_forecast_nominal(1:10)) +}) + + +# ============================================================================== +# get_metrics.forecast_nominal() +# ============================================================================== + +test_that("get_metrics.forecast_nominal() works as expected", { + expect_true( + is.list(get_metrics(example_nominal)) + ) +}) diff --git a/tests/testthat/test-class-forecast-point.R b/tests/testthat/test-class-forecast-point.R new file mode 100644 index 000000000..12666e5db --- /dev/null +++ b/tests/testthat/test-class-forecast-point.R @@ -0,0 +1,109 @@ +# ============================================================================== +# as_forecast_point() +# ============================================================================== + +test_that("as_forecast_point() works", { + expect_no_condition( + as_forecast_point(as_forecast_quantile(na.omit(example_quantile))) + ) +}) + + +# ============================================================================== +# is_forecast_point() +# ============================================================================== +test_that("is_forecast_point() works as expected", { + expect_true(is_forecast_point(example_point)) + expect_false(is_forecast_point(example_binary)) + expect_false(is_forecast_point(example_quantile)) + expect_false(is_forecast_point(example_sample_continuous)) + expect_false(is_forecast_point(example_nominal)) +}) + + +# ============================================================================== +# assert_forecast.forecast_point() +# 
============================================================================== + +test_that("assert_forecast.forecast_point() works as expected", { + test <- na.omit(data.table::as.data.table(example_point)) + test <- as_forecast_point(test) + + # expect an error if column is changed to character after initial validation. + expect_warning( + test <- test[, "predicted" := as.character(predicted)], + "Input looks like a point forecast, but found the following issue" + ) + expect_error( + assert_forecast(test), + "Input looks like a point forecast, but found the following issue" + ) +}) + +test_that("assert_forecast.forecast_point() complains if the forecast type is wrong", { + expect_error( + assert_forecast(na.omit(example_point), forecast_type = "quantile"), + "Forecast type determined by scoringutils based on input:" + ) +}) + + +# ============================================================================== +# score.forecast_point() +# ============================================================================== +test_that("function produces output for a point case", { + expect_equal( + names(scores_binary), + c(get_forecast_unit(example_binary), names(get_metrics(example_binary))) + ) + + eval <- summarise_scores(scores_point, by = c("model", "target_type")) + + expect_equal( + nrow(eval) > 1, + TRUE + ) + expect_equal( + colnames(eval), + c("model", "target_type", names(get_metrics(example_point))) + ) + + expect_s3_class(eval, c("scores", "data.table", "data.frame"), exact = TRUE) +}) + +test_that("Changing metrics names works", { + metrics_test <- get_metrics(example_point) + names(metrics_test)[1] = "just_testing" + eval <- suppressMessages(score(as_forecast_point(example_point), + metrics = metrics_test)) + eval_summarised <- summarise_scores(eval, by = "model") + expect_equal( + colnames(eval_summarised), + c("model", "just_testing", names(get_metrics(example_point))[-1]) + ) +}) + +test_that("score.forecast_point() errors with only NA values", { + # 
`[.forecast()` will warn even before score() + only_nas <- suppressWarnings( + copy(example_point)[, predicted := NA_real_] + ) + expect_error( + score(only_nas), + "After removing rows with NA values in the data, no forecasts are left." + ) +}) + +# ============================================================================== +# get_metrics.forecast_point() +# ============================================================================== +test_that("get_metrics.forecast_point() works as expected", { + expect_true( + is.list(get_metrics(example_point)) + ) + + expect_equal( + get_metrics.scores(scores_point), + c("ae_point", "se_point", "ape") + ) +}) diff --git a/tests/testthat/test-class-forecast-quantile.R b/tests/testthat/test-class-forecast-quantile.R new file mode 100644 index 000000000..87d746bfd --- /dev/null +++ b/tests/testthat/test-class-forecast-quantile.R @@ -0,0 +1,362 @@ +# ============================================================================== +# as_forecast_quantile() +# ============================================================================== + +test_that("as_forecast_quantile() works as expected", { + test <- na.omit(data.table::copy(example_quantile)) + + expect_s3_class( + as_forecast_quantile(test), + c("forecast_quantile", "forecast", "data.table", "data.frame"), + exact = TRUE + ) + + # expect error when arguments are not correct + expect_error(as_forecast_quantile(test, observed = 3), "Must be of type 'character'") + expect_error(as_forecast_quantile(test, quantile_level = c("1", "2")), "Must have length 1") + expect_error(as_forecast_quantile(test, observed = "missing"), "Must be a subset of") + + # expect no condition with columns already present + expect_no_condition( + as_forecast_quantile(test, + observed = "observed", predicted = "predicted", + forecast_unit = c( + "location", "model", "target_type", + "target_end_date", "horizon" + ), + quantile_level = "quantile_level" + ) + ) +}) + +test_that("as_forecast_quantile() 
function works", { + check <- suppressMessages(as_forecast_quantile(example_quantile)) + expect_s3_class(check, "forecast_quantile") +}) + +test_that("as_forecast_quantile() errors if there is both a sample_id and a quantile_level column", { + example <- as.data.table(example_quantile)[, sample_id := 1] + expect_error( + as_forecast_quantile(example), + "Found columns `quantile_level` and `sample_id`. Only one of these is allowed" + ) +}) + +test_that("as_forecast_quantile() warns if there are different numbers of quantiles", { + example <- as.data.table(example_quantile)[-1000, ] + expect_warning( + w <- as_forecast_quantile(na.omit(example)), + "Some forecasts have different numbers of rows" + ) + # printing should work without a warning because printing is silent + expect_no_condition(w) +}) + +test_that("as_forecast_quantile() function throws an error with duplicate forecasts", { + example <- rbind( + example_quantile, + example_quantile[1000:1010] + ) + + expect_error( + suppressMessages(suppressWarnings(as_forecast_quantile(example))), + "Assertion on 'data' failed: There are instances with more than one forecast for the same target. This can't be right and needs to be resolved. Maybe you need to check the unit of a single forecast and add missing columns? Use the function get_duplicate_forecasts() to identify duplicate rows.", # nolint + fixed = TRUE + ) +}) + +test_that("as_forecast_quantile() function throws an error when no predictions or observed values are present", { + expect_error( + suppressMessages(suppressWarnings(as_forecast_quantile( + data.table::copy(example_quantile)[, predicted := NULL] + ))), + "Assertion on 'data' failed: Column 'predicted' not found in data." + ) + + expect_error( + suppressMessages(suppressWarnings(as_forecast_quantile( + data.table::copy(example_quantile)[, observed := NULL] + ))), + "Assertion on 'data' failed: Column 'observed' not found in data." 
+ ) + + expect_error( + suppressMessages(suppressWarnings(as_forecast_quantile( + data.table::copy(example_quantile)[, c("observed", "predicted") := NULL] + ))), + "Assertion on 'data' failed: Columns 'observed', 'predicted' not found in data." + ) +}) + +test_that("as_forecast_quantile() works with a data.frame", { + expect_no_condition(as_forecast_quantile(example_quantile_df)) +}) + +test_that("as_forecast_quantiles works", { + samples <- data.frame( + date = as.Date("2020-01-01") + 1:10, + model = "model1", + observed = 1:10, + predicted = c(rep(0, 10), 2:11, 3:12, 4:13, rep(100, 10)), + sample_id = rep(1:5, each = 10) + ) %>% + as_forecast_sample() + + quantile <- data.frame( + date = rep(as.Date("2020-01-01") + 1:10, each = 2), + model = "model1", + observed = rep(1:10, each = 2), + quantile_level = c(0.25, 0.75), + predicted = rep(2:11, each = 2) + c(0, 2) + ) + + expect_no_condition( + as_forecast_quantile(samples, probs = c(0.25, 0.75)) + ) + + wrongclass <- as_forecast_sample(samples) + class(wrongclass) <- c("forecast_point", "data.table", "data.frame") + expect_error( + as_forecast_quantile(wrongclass, quantile_level = c(0.25, 0.75)), + "Assertion on 'quantile_level' failed: Must be of type" + ) + + + quantile2 <- as_forecast_quantile( + as_forecast_sample(samples), + probs = c(0.25, 0.75) + ) + + expect_equal(quantile, as.data.frame(quantile2)) + + # Verify that `type` is correctly scoped in as_forecast_quantile(), as it is + # also an argument. + # If it's not scoped well, the call to `as_forecast_quantile()` will fail. 
+ samples$type <- "test" + + quantile3 <- as_forecast_quantile( + as_forecast_sample(samples), + probs = c(0.25, 0.75) + ) + quantile3$type <- NULL + + expect_identical( + quantile2, + quantile3 + ) +}) + +test_that("as_forecast_quantiles issue 557 fix", { + out <- example_sample_discrete %>% + na.omit() %>% + as_forecast_quantile( + probs = c(0.01, 0.025, seq(0.05, 0.95, 0.05), 0.975, 0.99) + ) %>% + score() + + expect_equal(any(is.na(out$interval_coverage_deviation)), FALSE) +}) + + +# ============================================================================== +# is_forecast_quantile() +# ============================================================================== +test_that("is_forecast_quantile() works as expected", { + expect_true(is_forecast_quantile(example_quantile)) + expect_false(is_forecast_quantile(example_binary)) + expect_false(is_forecast_quantile(example_point)) + expect_false(is_forecast_quantile(example_sample_continuous)) + expect_false(is_forecast_quantile(example_nominal)) +}) + +# ============================================================================== +# score.forecast_quantile() +# ============================================================================== +test_that("score_quantile correctly handles separate results = FALSE", { + df <- example_quantile[model == "EuroCOVIDhub-ensemble" & + target_type == "Cases" & location == "DE"] + metrics <- get_metrics(example_quantile) + metrics$wis <- purrr::partial(wis, separate_results = FALSE) + eval <- score(df[!is.na(predicted)], metrics = metrics) + + expect_equal( + nrow(eval) > 1, + TRUE + ) + expect_true(all(names(get_metrics(example_quantile)) %in% colnames(eval))) + + expect_s3_class(eval, c("scores", "data.table", "data.frame"), exact = TRUE) +}) + + +test_that("score() quantile produces desired metrics", { + data <- data.frame( + observed = rep(1:10, each = 3), + predicted = rep(c(-0.3, 0, 0.3), 10) + rep(1:10, each = 3), + model = "Model 1", + date = as.Date("2020-01-01") + 
rep(1:10, each = 3), + quantile_level = rep(c(0.1, 0.5, 0.9), times = 10) + ) + + data <-suppressWarnings(suppressMessages(as_forecast_quantile(data))) + + out <- score(forecast = data, metrics = metrics_no_cov) + metrics <- c( + "dispersion", "underprediction", "overprediction", + "bias", "ae_median" + ) + + expect_true(all(metrics %in% colnames(out))) +}) + + +test_that("calculation of ae_median is correct for a quantile format case", { + eval <- summarise_scores(scores_quantile,by = "model") + + example <- as.data.table(example_quantile) + ae <- example[quantile_level == 0.5, ae := abs(observed - predicted)][!is.na(model), .(mean = mean(ae, na.rm = TRUE)), + by = "model" + ]$mean + + expect_equal(sort(eval$ae_median), sort(ae)) +}) + + +test_that("all quantile and range formats yield the same result", { + eval1 <- summarise_scores(scores_quantile, by = "model") + + df <- as.data.table(example_quantile) + + ae <- df[ + quantile_level == 0.5, ae := abs(observed - predicted)][ + !is.na(model), .(mean = mean(ae, na.rm = TRUE)), + by = "model" + ]$mean + + expect_equal(sort(eval1$ae_median), sort(ae)) +}) + +test_that("WIS is the same with other metrics omitted or included", { + eval <- score(example_quantile, + metrics = list("wis" = wis) + ) + + eval2 <- scores_quantile + + expect_equal( + sum(eval$wis), + sum(eval2$wis) + ) +}) + + +test_that("score.forecast_quantile() errors with only NA values", { + # [.forecast()` will warn even before score() + only_nas <- suppressWarnings( + copy(example_quantile)[, predicted := NA_real_] + ) + expect_error( + score(only_nas), + "After removing rows with NA values in the data, no forecasts are left." 
+ ) +}) + +test_that("score.forecast_quantile() works as expected in edge cases", { + # only the median + onlymedian <- example_quantile[quantile_level == 0.5] + expect_no_condition( + s <- score(onlymedian, metrics = get_metrics( + example_quantile, + exclude = c("interval_coverage_50", "interval_coverage_90") + )) + ) + expect_equal( + s$wis, abs(onlymedian$observed - onlymedian$predicted) + ) + + # only one symmetric interval is present + oneinterval <- example_quantile[quantile_level %in% c(0.25,0.75)] %>% + as_forecast_quantile() + expect_message( + s <- score( + oneinterval, + metrics = get_metrics( + example_quantile, + exclude = c("interval_coverage_90", "ae_median") + ) + ), + "Median not available" + ) +}) + +test_that("score() works even if only some quantiles are missing", { + + # only the median is there + onlymedian <- example_quantile[quantile_level == 0.5] + expect_no_condition( + score(onlymedian, metrics = get_metrics( + example_quantile, + exclude = c("interval_coverage_50", "interval_coverage_90") + )) + ) + + + # asymmetric intervals + asymm <- example_quantile[!quantile_level > 0.6] + expect_warning( + expect_warning( + score_a <- score(asymm) %>% summarise_scores(by = "model"), + "Computation for `interval_coverage_50` failed." + ), + "Computation for `interval_coverage_90` failed." 
+ ) + + # check that the result is equal to a case where we discard the entire + # interval in terms of WIS + inner <- example_quantile[quantile_level %in% c(0.4, 0.45, 0.5, 0.55, 0.6)] + score_b <- score(inner, get_metrics( + inner, exclude = c("interval_coverage_50", "interval_coverage_90") + )) %>% + summarise_scores(by = "model") + expect_equal( + score_a$wis, + score_b$wis + ) + + # median is not there, but only in a single model + test <- data.table::copy(example_quantile) + test_no_median <- test[model == "epiforecasts-EpiNow2" & !(quantile_level %in% c(0.5)), ] + test <- rbind(test[model != "epiforecasts-EpiNow2"], test_no_median) + + test <- suppressWarnings(as_forecast_quantile(test)) + expect_message( + expect_warning( + score(test), + "Computation for `ae_median` failed." + ), + "interpolating median from the two innermost quantiles" + ) +}) + +# ============================================================================== +# get_metrics.forecast_quantile() +# ============================================================================== +test_that("get_metrics.forecast_quantile() works as expected", { + expect_true( + is.list(get_metrics(example_quantile)) + ) +}) + + +# ============================================================================== +# get_pit.forecast_quantile() +# ============================================================================== +test_that("get_pit.forecast_quantile() works as expected", { + pit_quantile <- get_pit(example_quantile, by = "model") + + expect_equal(names(pit_quantile), c("model", "quantile_level", "pit_value")) + expect_s3_class(pit_quantile, c("data.table", "data.frame"), exact = TRUE) + + # check printing works + expect_output(print(pit_quantile)) +}) diff --git a/tests/testthat/test-class-forecast-sample.R b/tests/testthat/test-class-forecast-sample.R new file mode 100644 index 000000000..5a5715e31 --- /dev/null +++ b/tests/testthat/test-class-forecast-sample.R @@ -0,0 +1,76 @@ +# 
============================================================================== +# as_forecast_sample() +# ============================================================================== +test_that("as_forecast_sample() works as expected", { + test <- na.omit(data.table::copy(example_sample_continuous)) + data.table::setnames(test, + old = c("observed", "predicted", "sample_id"), + new = c("obs", "pred", "sample") + ) + expect_no_condition( + as_forecast_sample(test, + observed = "obs", predicted = "pred", + forecast_unit = c( + "location", "model", "target_type", + "target_end_date", "horizon" + ), + sample_id = "sample" + ) + ) +}) + +test_that("Running `as_forecast_sample()` twice returns the same object", { + ex <- na.omit(example_sample_continuous) + + expect_identical( + as_forecast_sample(as_forecast_sample(ex)), + as_forecast_sample(ex) + ) +}) + + +# ============================================================================== +# is_forecast_sample() +# ============================================================================== +test_that("is_forecast_sample() works as expected", { + expect_true(is_forecast_sample(example_sample_continuous)) + expect_false(is_forecast_sample(example_binary)) + expect_false(is_forecast_sample(example_point)) + expect_false(is_forecast_sample(example_quantile)) + expect_false(is_forecast_sample(example_nominal)) + expect_false(is_forecast_sample(1:10)) +}) + + +# ============================================================================== +# get_metrics.forecast_sample() +# ============================================================================== + +test_that("get_metrics.forecast_sample() works as expected", { + expect_true( + is.list(get_metrics(example_sample_continuous)) + ) + expect_true( + is.list(get_metrics(example_sample_discrete)) + ) +}) + + +# ============================================================================== +# get_pit.forecast_sample() +# 
============================================================================== +test_that("get_pit.forecast_sample() works as expected", { + pit_continuous <- get_pit(example_sample_continuous, by = c("model", "target_type")) + pit_integer <- get_pit(example_sample_discrete, by = c("model", "location")) + + expect_equal(names(pit_continuous), c("model", "target_type", "pit_value")) + expect_equal(names(pit_integer), c("model", "location", "pit_value")) + + # check printing works + expect_output(print(pit_continuous)) + expect_output(print(pit_integer)) + + # check class is correct + expect_s3_class(pit_continuous, c("data.table", "data.frame"), exact = TRUE) + expect_s3_class(pit_integer, c("data.table", "data.frame"), exact = TRUE) +}) diff --git a/tests/testthat/test-class-forecast.R b/tests/testthat/test-class-forecast.R new file mode 100644 index 000000000..3054c3fe6 --- /dev/null +++ b/tests/testthat/test-class-forecast.R @@ -0,0 +1,218 @@ +# ============================================================================== +# as_forecast() +# ============================================================================== +# see tests for each forecast type for more specific tests. 
+ + +# ============================================================================== +# is_forecast() +# ============================================================================== + +test_that("is_forecast() works as expected", { + expect_true(is_forecast(example_binary)) + expect_true(is_forecast(example_point)) + expect_true(is_forecast(example_quantile)) + expect_true(is_forecast(example_sample_continuous)) + expect_true(is_forecast(example_nominal)) + + expect_false(is_forecast(1:10)) + expect_false(is_forecast(data.table::as.data.table(example_point))) +}) + + +# ============================================================================== +# assert_forecast() and assert_forecast_generic() +# ============================================================================== + +test_that("assert_forecast() works as expected", { + # test that by default, `as_forecast()` errors + expect_error( + assert_forecast(data.frame(x = 1:10)), + "The input needs to be a valid forecast object." + ) +}) + +test_that("assert_forecast_generic() works as expected with a data.frame", { + expect_error( + assert_forecast_generic(example_quantile_df), + "Assertion on 'data' failed: Must be a data.table, not data.frame." + ) +}) + + +# ============================================================================== +# new_forecast() +# ============================================================================== + +test_that("new_forecast() works as expected with a data.frame", { + expect_s3_class( + new_forecast(example_quantile_df, "quantile"), + c("forecast_quantile", "data.table", "data.frame") + ) +}) + + +# ============================================================================== +# [.forecast() +# ============================================================================== + +test_that("[.forecast() immediately invalidates on change when necessary", { + test <- na.omit(data.table::copy(example_quantile)) + + # For cols; various ways to drop. 
+ # We use local() to avoid actual deletion in this frame and having to recreate + # the input multiple times + expect_warning( + local(test[, colnames(test) != "observed", with = FALSE]), + "Error in validating" + ) + + expect_warning( + local(test[, "observed"] <- NULL), + "Error in validating" + ) + + expect_warning( + local(test$observed <- NULL), + "Error in validating" + ) + + expect_warning( + local(test[["observed"]] <- NULL), + "Error in validating" + ) + + # For rows + expect_warning( + local(test[2, ] <- test[1, ]) + ) +}) + +test_that("[.forecast() doesn't warn on cases where the user likely didn't intend getting a forecast object", { + test <- as_forecast_quantile(na.omit(example_quantile)) + + expect_no_condition(test[, location]) +}) + +test_that("[.forecast() is compatible with data.table syntax", { + test <- as_forecast_quantile(na.omit(example_quantile)) + + expect_no_condition( + test[location == "DE"] + ) + + expect_no_condition( + test[ + target_type == "Cases", + .(location, target_end_date, observed, location_name, forecast_date, quantile_level, predicted, model) + ] + ) +}) + + +# ============================================================================== +# print.forecast() +# ============================================================================== +test_that("print() works on forecast_* objects", { + # Check print works on each forecast object + test_dat <- list( + example_binary, example_quantile, + example_point, example_sample_continuous, + example_sample_discrete + ) + test_dat <- lapply(test_dat, na.omit) + for (dat in test_dat) { + forecast_type <- scoringutils:::get_forecast_type(dat) + forecast_unit <- get_forecast_unit(dat) + + fn_name <- paste0("as_forecast_", forecast_type) + fn <- get(fn_name) + dat <- suppressWarnings(suppressMessages(do.call(fn, list(dat)))) + + # Check Forecast type + expect_snapshot(print(dat)) + expect_snapshot(print(dat)) + # Check Forecast unit + expect_snapshot(print(dat)) + 
expect_snapshot(print(dat)) + + # Check print.data.table works. + output_original <- suppressMessages(capture.output(print(dat))) + output_test <- suppressMessages(capture.output(print(data.table(dat)))) + expect_contains(output_original, output_test) + } +}) + +test_that("print() throws the expected messages", { + test <- data.table::copy(example_point) + class(test) <- c("point", "forecast", "data.table", "data.frame") + suppressMessages( + expect_message( + capture.output(print(test)), + "Could not determine forecast type due to error in validation." + ) + ) + + class(test) <- c("forecast_point", "forecast") + suppressMessages( + expect_message( + capture.output(print(test)), + "Could not determine forecast unit." + ) + ) +}) + + +# ============================================================================== +# check_number_per_forecast() +# ============================================================================== +test_that("check_number_per_forecast works", { + expect_identical( + capture.output( + check_number_per_forecast( + example_binary, + forecast_unit = "location_name" + ) + ), + paste( + "[1] \"Some forecasts have different numbers of rows", + "(e.g. quantiles or samples). 
scoringutils found: 224, 215.", + "This may be a problem (it can potentially distort scores,", + "making it more difficult to compare them),", + "so make sure this is intended.\"" + ) + ) + expect_true( + check_number_per_forecast( + example_binary + ) + ) +}) + + +# ============================================================================== +# Test removing `NA` values from the data +# ============================================================================== +test_that("removing NA rows from data works as expected", { + expect_equal(nrow(na.omit(example_quantile)), 20401) + + ex <- data.frame(observed = c(NA, 1:3), predicted = 1:4) + expect_equal(nrow(na.omit(ex)), 3) + + ex$predicted <- c(1:3, NA) + expect_equal(nrow(na.omit(ex)), 2) + + # test that attributes and classes are retained + ex <- as_forecast_sample(na.omit(example_sample_discrete)) + expect_s3_class( + na.omit(ex), + c("forecast_sample", "forecast", "data.table", "data.frame"), + exact = TRUE + ) + + attributes <- attributes(ex) + expect_equal( + attributes(na.omit(ex)), + attributes + ) +}) diff --git a/tests/testthat/test-class-scores.R b/tests/testthat/test-class-scores.R new file mode 100644 index 000000000..443f9f6cd --- /dev/null +++ b/tests/testthat/test-class-scores.R @@ -0,0 +1,34 @@ +# ============================================================================== +# get_metrics.scores() +# ============================================================================== +test_that("get_metrics.scores() works as expected", { + expect_null( + get_metrics.scores(as.data.frame(as.matrix(scores_point))) + ) + + expect_true( + "brier_score" %in% get_metrics.scores(scores_binary) + ) + + expect_equal( + get_metrics.scores(scores_sample_continuous), + attr(scores_sample_continuous, "metrics") + ) + + # check that function errors if `error = TRUE` and not otherwise + expect_error( + get_metrics.scores(example_quantile, error = TRUE), + "Input needs an attribute" + ) + expect_no_condition( + 
get_metrics.scores(scores_sample_continuous) + ) + + # expect warning if some column changed + ex <- data.table::copy(scores_sample_continuous) + data.table::setnames(ex, old = "crps", new = "changed") + expect_warning( + get_metrics.scores(ex), + "scores have been previously computed, but are no longer column names" + ) +}) diff --git a/tests/testthat/test-customise_metric.R b/tests/testthat/test-customise_metric.R deleted file mode 100644 index e69de29bb..000000000 diff --git a/tests/testthat/test-forecast-unit.R b/tests/testthat/test-forecast-unit.R new file mode 100644 index 000000000..e83ba5b98 --- /dev/null +++ b/tests/testthat/test-forecast-unit.R @@ -0,0 +1,98 @@ +# ============================================================================ # +# `set_forecast_unit()` +# ============================================================================ # + +test_that("function set_forecast_unit() works", { + # some columns in the example data have duplicated information. So we can remove + # these and see whether the result stays the same. + scores1 <- scores_quantile[order(location, target_end_date, target_type, horizon, model), ] + + # test that if setting the forecast unit results in an invalid object, + # a warning occurs. 
+ expect_warning( + set_forecast_unit(example_quantile, "model"), + "Assertion on 'data' failed: There are instances with more" + ) + + ex2 <- set_forecast_unit( + example_quantile, + c("location", "target_end_date", "target_type", "horizon", "model") + ) + scores2 <- score(na.omit(ex2)) + scores2 <- scores2[order(location, target_end_date, target_type, horizon, model), ] + + expect_equal(scores1$interval_score, scores2$interval_score) +}) + +test_that("set_forecast_unit() works on input that's not a data.table", { + df <- data.frame( + a = 1:2, + b = 2:3, + c = 3:4 + ) + expect_equal( + colnames(set_forecast_unit(df, c("a", "b"))), + c("a", "b") + ) + + expect_equal( + names(set_forecast_unit(as.matrix(df), "a")), + "a" + ) + + expect_s3_class( + set_forecast_unit(df, c("a", "b")), + c("data.table", "data.frame"), + exact = TRUE + ) +}) + +test_that("set_forecast_unit() revalidates a forecast object", { + obj <- as_forecast_quantile(na.omit(example_quantile)) + expect_no_condition( + set_forecast_unit(obj, c("location", "target_end_date", "target_type", "model", "horizon")) + ) +}) + + +test_that("function set_forecast_unit() errors when column is not there", { + expect_error( + set_forecast_unit( + example_quantile, + c("location", "target_end_date", "target_type", "horizon", "model", "test1", "test2") + ), + "Assertion on 'forecast_unit' failed: Must be a subset of " + ) +}) + +test_that("function get_forecast_unit() and set_forecast_unit() work together", { + fu_set <- c("location", "target_end_date", "target_type", "horizon", "model") + ex <- set_forecast_unit(example_binary, fu_set) + fu_get <- get_forecast_unit(ex) + expect_equal(fu_set, fu_get) +}) + +test_that("output class of set_forecast_unit() is as expected", { + ex <- as_forecast_binary(na.omit(example_binary)) + expect_equal( + class(ex), + class(set_forecast_unit(ex, c("location", "target_end_date", "target_type", "horizon", "model"))) + ) +}) + + +# 
============================================================================== +# `get_forecast_unit()` +# ============================================================================== +test_that("get_forecast_unit() works as expected", { + fc <- c( + "location", "target_end_date", "target_type", "location_name", + "forecast_date", "model", "horizon" + ) + + expect_equal(get_forecast_unit(example_quantile), fc) + expect_equal(get_forecast_unit(scores_quantile), fc) + + # test with data.frame + expect_equal(get_forecast_unit(as.data.frame(example_quantile)), fc) +}) diff --git a/tests/testthat/test-forecast.R b/tests/testthat/test-forecast.R deleted file mode 100644 index b462d20b3..000000000 --- a/tests/testthat/test-forecast.R +++ /dev/null @@ -1,340 +0,0 @@ -# ============================================================================== -# as_forecast() -# ============================================================================== - -test_that("Running `as_forecast_sample()` twice returns the same object", { - ex <- na.omit(example_sample_continuous) - - expect_identical( - as_forecast_sample(as_forecast_sample(ex)), - as_forecast_sample(ex) - ) -}) - -test_that("as_forecast works with a data.frame", { - expect_no_condition(as_forecast_quantile(example_quantile_df)) -}) - -test_that("as_forecast() works as expected", { - test <- na.omit(data.table::copy(example_quantile)) - - expect_s3_class( - as_forecast_quantile(test), - c("forecast_quantile", "forecast", "data.table", "data.frame"), - exact = TRUE) - - # expect error when arguments are not correct - expect_error(as_forecast_quantile(test, observed = 3), "Must be of type 'character'") - expect_error(as_forecast_quantile(test, quantile_level = c("1", "2")), "Must have length 1") - expect_error(as_forecast_quantile(test, observed = "missing"), "Must be a subset of") - - # expect no condition with columns already present - expect_no_condition( - as_forecast_quantile(test, - observed = "observed", predicted = 
"predicted", - forecast_unit = c( - "location", "model", "target_type", - "target_end_date", "horizon" - ), - quantile_level = "quantile_level" - ) - ) - - # additional test with renaming the model column - test <- na.omit(data.table::copy(example_sample_continuous)) - data.table::setnames(test, - old = c("observed", "predicted", "sample_id"), - new = c("obs", "pred", "sample") - ) - expect_no_condition( - as_forecast_sample(test, - observed = "obs", predicted = "pred", - forecast_unit = c( - "location", "model", "target_type", - "target_end_date", "horizon" - ), - sample_id = "sample" - ) - ) -}) - -test_that("as_forecast() function works", { - check <- suppressMessages(as_forecast_quantile(example_quantile)) - expect_s3_class(check, "forecast_quantile") -}) - -test_that("as_forecast() function has an error for empty data.frame", { - d <- data.frame(observed = numeric(), predicted = numeric(), model = character()) - - expect_error( - as_forecast_point(d), - "Assertion on 'data' failed: Must have at least 1 rows, but has 0 rows." - ) -}) - -test_that("as_forecast() errors if there is both a sample_id and a quantile_level column", { - example <- as.data.table(example_quantile)[, sample_id := 1] - expect_error( - as_forecast_quantile(example), - "Found columns `quantile_level` and `sample_id`. 
Only one of these is allowed" - ) -}) - -test_that("as_forecast() warns if there are different numbers of quantiles", { - example <- as.data.table(example_quantile)[-1000, ] - expect_warning( - w <- as_forecast_quantile(na.omit(example)), - "Some forecasts have different numbers of rows" - ) - # printing should work without a warning because printing is silent - expect_no_condition(w) -}) - -test_that("as_forecast_point() works", { - expect_no_condition( - as_forecast_point(as_forecast_quantile(na.omit(example_quantile))) - ) -}) - -test_that("check_columns_present() works", { - expect_equal( - check_columns_present(example_quantile, c("observed", "predicted", "nop")), - "Column 'nop' not found in data" - ) - expect_true( - check_columns_present(example_quantile, c("observed", "predicted")) - ) -}) - -test_that("check_duplicates() works", { - bad <- rbind( - example_quantile[1000:1010], - example_quantile[1000:1010] - ) - - expect_equal(scoringutils:::check_duplicates(bad), - "There are instances with more than one forecast for the same target. This can't be right and needs to be resolved. Maybe you need to check the unit of a single forecast and add missing columns? Use the function get_duplicate_forecasts() to identify duplicate rows" - ) -}) - -test_that("as_forecast() function throws an error with duplicate forecasts", { - example <- rbind(example_quantile, - example_quantile[1000:1010]) - - expect_error( - suppressMessages(suppressWarnings(as_forecast_quantile(example))), - "Assertion on 'data' failed: There are instances with more than one forecast for the same target. This can't be right and needs to be resolved. Maybe you need to check the unit of a single forecast and add missing columns? 
Use the function get_duplicate_forecasts() to identify duplicate rows.", #nolint - fixed = TRUE - ) -}) - -test_that("as_forecast_quantile() function throws an error when no predictions or observed values are present", { - expect_error(suppressMessages(suppressWarnings(as_forecast_quantile( - data.table::copy(example_quantile)[, predicted := NULL] - ))), - "Assertion on 'data' failed: Column 'predicted' not found in data.") - - expect_error(suppressMessages(suppressWarnings(as_forecast_quantile( - data.table::copy(example_quantile)[, observed := NULL] - ))), - "Assertion on 'data' failed: Column 'observed' not found in data.") - - expect_error(suppressMessages(suppressWarnings(as_forecast_quantile( - data.table::copy(example_quantile)[, c("observed", "predicted") := NULL] - ))), - "Assertion on 'data' failed: Columns 'observed', 'predicted' not found in data.") -}) - - -test_that("output of as_forecasts() is accepted as input to score()", { - check <- suppressMessages(as_forecast_binary(example_binary)) - expect_no_error( - score_check <- score(na.omit(check)) - ) - expect_equal(score_check, suppressMessages(score(as_forecast_binary(example_binary)))) -}) - - -# as_forecast.forecast_nominal() ----------------------------------------------- -test_that("as_forecast.forecast_nominal() works as expected", { - expect_s3_class( - suppressMessages(as_forecast_nominal(example_nominal)), - c("forecast_nominal", "forecast", "data.table", "data.frame"), - exact = TRUE - ) - - ex <- data.table::copy(example_nominal) %>% - na.omit() - setnames(ex, old = "predicted_label", new = "label") - - expect_no_condition( - as_forecast_nominal(ex, predicted_label = "label") - ) -}) - -test_that("as_forecast.forecast_nominal() breaks when rows with zero probability are missing", { - ex_faulty <- as.data.table(example_nominal) - ex_faulty <- ex_faulty[predicted != 0] - expect_warning( - expect_error( - as_forecast_nominal(ex_faulty), - "Found incomplete forecasts" - ), - "Some forecasts 
have different numbers of rows" - ) -}) - - -# ============================================================================== -# is_forecast() -# ============================================================================== - -test_that("is_forecast() works as expected", { - ex_binary <- suppressMessages(as_forecast_binary(example_binary)) - ex_point <- suppressMessages(as_forecast_point(example_point)) - ex_quantile <- suppressMessages(as_forecast_quantile(example_quantile)) - ex_continuous <- suppressMessages(as_forecast_sample(example_sample_continuous)) - ex_nominal <- suppressMessages(as_forecast_nominal(example_nominal)) - - expect_true(is_forecast(ex_binary)) - expect_true(is_forecast_binary(ex_binary)) - expect_true(is_forecast_point(ex_point)) - expect_true(is_forecast_quantile(ex_quantile)) - expect_true(is_forecast(ex_continuous)) - expect_true(is_forecast_nominal(ex_nominal)) - - expect_false(is_forecast(1:10)) - expect_false(is_forecast(data.table::as.data.table(example_point))) - expect_false(is_forecast_sample(ex_quantile)) - expect_false(is_forecast_quantile(ex_binary)) -}) - - -# ============================================================================== -# assert_forecast() -# ============================================================================== - -test_that("assert_forecast() works as expected", { - # test that by default, `as_forecast()` errors - expect_error(assert_forecast(data.frame(x = 1:10)), - "The input needs to be a valid forecast object.") -}) - -test_that("assert_forecast.forecast_binary works as expected", { - test <- na.omit(as.data.table(example_binary)) - test[, "sample_id" := 1:nrow(test)] - - # error if there is a superfluous sample_id column - expect_error( - as_forecast_binary(test), - "Input looks like a binary forecast, but an additional column called `sample_id` or `quantile` was found." 
- ) - - # expect error if probabilties are not in [0, 1] - test <- na.omit(as.data.table(example_binary)) - test[, "predicted" := predicted + 1] - expect_error( - as_forecast_binary(test), - "Input looks like a binary forecast, but found the following issue" - ) -}) - -test_that("assert_forecast.forecast_point() works as expected", { - test <- na.omit(as.data.table(example_point)) - test <- as_forecast_point(test) - - # expect an error if column is changed to character after initial validation. - expect_warning( - test <- test[, "predicted" := as.character(predicted)], - "Input looks like a point forecast, but found the following issue" - ) - expect_error( - assert_forecast(test), - "Input looks like a point forecast, but found the following issue" - ) -}) - -test_that("assert_forecast() complains if the forecast type is wrong", { - test <- na.omit(data.table::copy(example_point)) - test <- as_forecast_point(test) - expect_error( - assert_forecast(test, forecast_type = "quantile"), - "Forecast type determined by scoringutils based on input:" - ) -}) - -test_that("assert_forecast_generic() works as expected with a data.frame", { - expect_error( - assert_forecast_generic(example_quantile_df), - "Assertion on 'data' failed: Must be a data.table, not data.frame." 
- ) -}) - - -# ============================================================================== -# new_forecast() -# ============================================================================== - -test_that("new_forecast() works as expected with a data.frame", { - expect_s3_class( - new_forecast(example_quantile_df, "quantile"), - c("forecast_quantile", "data.table", "data.frame") - ) -}) - -# ============================================================================== -# [.forecast() -# ============================================================================== - -test_that("[.forecast() immediately invalidates on change when necessary", { - test <- as_forecast_quantile(na.omit(example_quantile)) - - # For cols; various ways to drop. - # We use local() to avoid actual deletion in this frame and having to recreate - # the input multiple times - expect_warning( - local(test[, colnames(test) != "observed", with = FALSE]), - "Error in validating" - ) - - expect_warning( - local(test[, "observed"] <- NULL), - "Error in validating" - ) - - expect_warning( - local(test$observed <- NULL), - "Error in validating" - ) - - expect_warning( - local(test[["observed"]] <- NULL), - "Error in validating" - ) - - # For rows - expect_warning( - local(test[2, ] <- test[1, ]) - ) -}) - -test_that("[.forecast() doesn't warn on cases where the user likely didn't intend getting a forecast object", { - test <- as_forecast_quantile(na.omit(example_quantile)) - - expect_no_condition(test[, location]) -}) - -test_that("[.forecast() is compatible with data.table syntax", { - - test <- as_forecast_quantile(na.omit(example_quantile)) - - expect_no_condition( - test[location == "DE"] - ) - - expect_no_condition( - test[target_type == "Cases", - .(location, target_end_date, observed, location_name, forecast_date, quantile_level, predicted, model)] - ) - -}) diff --git a/tests/testthat/test-get_correlations.R b/tests/testthat/test-get-correlations.R similarity index 58% rename from 
tests/testthat/test-get_correlations.R rename to tests/testthat/test-get-correlations.R index 71715a523..83f769cdd 100644 --- a/tests/testthat/test-get_correlations.R +++ b/tests/testthat/test-get-correlations.R @@ -3,7 +3,6 @@ test_that("get_correlations() works as expected", { # expect all to go well in the usual case expect_no_condition( correlations <- scores_quantile %>% - summarise_scores(by = get_forecast_unit(scores_quantile)) %>% get_correlations() ) expect_equal( @@ -35,3 +34,25 @@ test_that("get_correlations() works as expected", { "Assertion on 'metrics' failed: Must be a subset of" ) }) + +# ============================================================================== +# plot_correlation() +# ============================================================================== +test_that("plot_correlations() works as expected", { + correlations <- get_correlations( + summarise_scores( + scores_quantile, + by = get_forecast_unit(scores_quantile) + ) + ) + p <- plot_correlations(correlations, digits = 2) + expect_s3_class(p, "ggplot") + skip_on_cran() + vdiffr::expect_doppelganger("plot__correlation", p) + + # expect an error if you forgot to compute correlations + expect_error( + plot_correlations(summarise_scores(scores_quantile)), + "Did you forget to call `scoringutils::get_correlations()`?" 
+ ) +}) diff --git a/tests/testthat/test-get-coverage.R b/tests/testthat/test-get-coverage.R new file mode 100644 index 000000000..0b0a57175 --- /dev/null +++ b/tests/testthat/test-get-coverage.R @@ -0,0 +1,92 @@ +# ============================================================================== +# `get_coverage()` +# ============================================================================== +ex_coverage <- example_quantile[model == "EuroCOVIDhub-ensemble"] + +test_that("get_coverage() works as expected", { + cov <- example_quantile %>% + get_coverage(by = get_forecast_unit(example_quantile)) + + expect_equal( + sort(colnames(cov)), + sort(c(get_forecast_unit(example_quantile), c( + "interval_range", "quantile_level", "interval_coverage", "interval_coverage_deviation", + "quantile_coverage", "quantile_coverage_deviation" + ))) + ) + + expect_equal(nrow(cov), nrow(na.omit(example_quantile))) + + expect_s3_class( + cov, + c("data.table", "data.frame"), + exact = TRUE + ) +}) + +test_that("get_coverage() outputs an object of class c('data.table', 'data.frame'", { + cov <- get_coverage(example_quantile) + expect_s3_class(cov, c("data.table", "data.frame"), exact = TRUE) +}) + +test_that("get_coverage() can deal with non-symmetric prediction intervals", { + # the expected result is that `get_coverage()` just works. However, + # all interval coverages with missing values should just be `NA` + test <- data.table::copy(example_quantile) + test <- test[!quantile_level %in% c(0.2, 0.3, 0.5)] + + expect_no_condition(cov <- get_coverage(test)) + + prediction_intervals <- get_range_from_quantile(c(0.2, 0.3, 0.5)) + + missing <- cov[interval_range %in% prediction_intervals] + not_missing <- cov[!interval_range %in% prediction_intervals] + + expect_true(all(is.na(missing$interval_coverage))) + expect_false(any(is.na(not_missing))) + + # test for a version where values are not missing, but just `NA` + # since `get_coverage()` calls `na.omit`, the result should be the same. 
+ test <- data.table::copy(example_quantile) + test <- test[quantile_level %in% c(0.2, 0.3, 0.5), predicted := NA] + cov2 <- get_coverage(test) + expect_equal(cov, cov2) +}) + + +# ============================================================================== +# plot_interval_coverage() +# ============================================================================== +test_that("plot_interval_coverage() works as expected", { + coverage <- example_quantile %>% + na.omit() %>% + as_forecast_quantile() %>% + get_coverage(by = c("model")) + p <- plot_interval_coverage(coverage) + expect_s3_class(p, "ggplot") + skip_on_cran() + suppressWarnings(vdiffr::expect_doppelganger("plot_interval_coverage", p)) + + # make sure that plot_interval_coverage() doesn't drop column names + expect_true(all(c( + "interval_coverage", "interval_coverage_deviation", + "quantile_coverage", "quantile_coverage_deviation" + ) %in% + names(coverage))) +}) + + +# ============================================================================== +# plot_quantile_coverage() +# ============================================================================== +test_that("plot_quantile_coverage() works as expected", { + coverage <- example_quantile %>% + na.omit() %>% + as_forecast_quantile() %>% + get_coverage(by = c("model", "quantile_level")) + + p <- plot_quantile_coverage(coverage) + expect_s3_class(p, "ggplot") + skip_on_cran() + suppressWarnings(vdiffr::expect_doppelganger("plot_quantile_coverage", p)) +}) diff --git a/tests/testthat/test-get-duplicate-forecasts.R b/tests/testthat/test-get-duplicate-forecasts.R new file mode 100644 index 000000000..68e7c9151 --- /dev/null +++ b/tests/testthat/test-get-duplicate-forecasts.R @@ -0,0 +1,112 @@ +# ============================================================================== +# get_duplicate_forecasts() +# ============================================================================== +test_that("get_duplicate_forecasts() works as expected for quantile", { 
+ expect_no_condition(get_duplicate_forecasts( + example_quantile, + forecast_unit = + c("location", "target_end_date", "target_type", "location_name", + "forecast_date", "model") + ) + ) + + expect_equal(nrow(get_duplicate_forecasts(example_quantile)), 0) + expect_equal( + nrow( + get_duplicate_forecasts(rbind(example_quantile, example_quantile[1000:1010]))), + 22 + ) +}) + +test_that("get_duplicate_forecasts() works as expected for sample", { + expect_equal(nrow(get_duplicate_forecasts(example_sample_continuous)), 0) + expect_equal( + nrow( + get_duplicate_forecasts(rbind(example_sample_continuous, example_sample_continuous[1040:1050]))), + 22 + ) +}) + + +test_that("get_duplicate_forecasts() works as expected for binary", { + expect_equal(nrow(get_duplicate_forecasts(example_binary)), 0) + expect_equal( + nrow( + get_duplicate_forecasts(rbind(example_binary, example_binary[1000:1010]))), + 22 + ) +}) + +test_that("get_duplicate_forecasts() works as expected for point", { + expect_equal(nrow(get_duplicate_forecasts(example_binary)), 0) + expect_equal( + nrow( + get_duplicate_forecasts(rbind(example_point, example_point[1010:1020]))), + 22 + ) + + expect_s3_class( + get_duplicate_forecasts(as.data.frame(example_point)), + c("data.table", "data.frame"), + exact = TRUE + ) +}) + +test_that("get_duplicate_forecasts() returns the expected class", { + expect_equal( + class(get_duplicate_forecasts(example_point)), + c("data.table", "data.frame") + ) +}) + +test_that("get_duplicate_forecasts() works as expected with a data.frame", { + duplicates <- get_duplicate_forecasts( + rbind(example_quantile_df, example_quantile_df[101:110, ]) + ) + expect_equal(nrow(duplicates), 20) +}) + +test_that("get_duplicate_forecasts() shows counts correctly", { + duplicates <- get_duplicate_forecasts( + rbind(example_quantile, example_quantile[101:110, ]), + counts = TRUE + ) + expect_equal(nrow(duplicates), 2) + expect_equal(unique(duplicates$n_duplicates), 10) +}) + + +# 
============================================================================== +# check_duplicates() +# ============================================================================== +test_that("check_duplicates works", { + example_bin <- rbind(example_binary[1000:1002, ], example_binary[1000:1002, ]) + expect_identical( + capture.output( + check_duplicates(example_bin) + ), + paste( + "[1] \"There are instances with more than one forecast for the same", + "target. This can't be right and needs to be resolved.", + "Maybe you need to check the unit of a single forecast and add", + "missing columns? Use the function get_duplicate_forecasts() to", + "identify duplicate rows\"" + ) + ) + expect_true( + check_duplicates(example_binary) + ) +}) + + +test_that("check_duplicates() works", { + bad <- rbind( + example_quantile[1000:1010], + example_quantile[1000:1010] + ) + + expect_equal(check_duplicates(bad), + "There are instances with more than one forecast for the same target. This can't be right and needs to be resolved. Maybe you need to check the unit of a single forecast and add missing columns? 
Use the function get_duplicate_forecasts() to identify duplicate rows" + ) +}) + diff --git a/tests/testthat/test-get-forecast-counts.R b/tests/testthat/test-get-forecast-counts.R new file mode 100644 index 000000000..ee0abea57 --- /dev/null +++ b/tests/testthat/test-get-forecast-counts.R @@ -0,0 +1,65 @@ +# ============================================================================== +# `get_forecast_counts()` +# ============================================================================== +test_that("get_forecast_counts() works as expected", { + af <- data.table::copy(example_quantile) + af <- get_forecast_counts( + af, + by = c("model", "target_type", "target_end_date") + ) + + expect_type(af, "list") + expect_type(af$target_type, "character") + expect_type(af$`count`, "integer") + expect_equal(nrow(af[is.na(`count`)]), 0) + af <- example_quantile %>% + get_forecast_counts(by = "model") + expect_equal(nrow(af), 4) + expect_equal(af$`count`, c(256, 256, 128, 247)) + + # Ensure the returning object class is exactly same as a data.table. 
+ expect_s3_class(af, c("data.table", "data.frame"), exact = TRUE) + + # Setting `collapse = c()` means that all quantiles and samples are counted + af <- example_quantile %>% + get_forecast_counts(by = "model", collapse = c()) + expect_equal(nrow(af), 4) + expect_equal(af$`count`, c(5888, 5888, 2944, 5681)) + + af <- example_quantile %>% + get_forecast_counts() + expect_equal(nrow(af), 50688) + + expect_error( + get_forecast_counts(example_quantile, by = NULL), + "Assertion on 'by' failed: Must be a subset of" + ) + + # check whether collapsing also works for model-based forecasts + af <- example_sample_discrete %>% + get_forecast_counts(by = "model") + expect_equal(nrow(af), 4) + + af <- example_sample_discrete %>% + get_forecast_counts(by = "model", collapse = c()) + expect_equal(af$count, c(10240, 10240, 5120, 9880)) +}) + + +# ============================================================================== +# `plot_forecast_counts()` +# ============================================================================== +test_that("plot_forecast_counts() works as expected", { + available_forecasts <- na.omit(example_quantile) %>% + as_forecast_quantile() %>% + get_forecast_counts( + by = c("model", "target_type", "target_end_date") + ) + p <- plot_forecast_counts(available_forecasts, + x = "target_end_date", show_counts = FALSE + ) + + facet_wrap("target_type") + expect_s3_class(p, "ggplot") + skip_on_cran() + vdiffr::expect_doppelganger("plot_available_forecasts", p) +}) diff --git a/tests/testthat/test-get-forecast-type.R b/tests/testthat/test-get-forecast-type.R new file mode 100644 index 000000000..51deb1348 --- /dev/null +++ b/tests/testthat/test-get-forecast-type.R @@ -0,0 +1,101 @@ +# ============================================================================== +# `get_forecast_type` +# ============================================================================== +test_that("get_forecast_type() works as expected", { + 
expect_equal(get_forecast_type(example_quantile), "quantile") + expect_equal(get_forecast_type(example_sample_continuous), "sample") + expect_equal(get_forecast_type(example_sample_discrete), "sample") + expect_equal(get_forecast_type(example_binary), "binary") + expect_equal(get_forecast_type(example_point), "point") + expect_equal(get_forecast_type(example_nominal), "nominal") + + expect_error( + get_forecast_type(data.frame(x = 1:10)), + "Input is not a valid forecast object", + fixed = TRUE + ) + + test <- test <- data.table::copy(example_quantile) + class(test) <- c("forecast", "data.table", "data.frame") + expect_error( + get_forecast_type(test), + "Input is not a valid forecast object", + ) +}) + + +# ============================================================================== +# `get_type()` +# ============================================================================== +test_that("get_type() works as expected with vectors", { + expect_equal(get_type(1:3), "integer") + expect_equal(get_type(factor(1:2)), "classification") + expect_equal(get_type(c(1.0, 2)), "integer") + expect_equal(get_type(c(1.0, 2.3)), "continuous") + expect_error( + get_type(c("a", "b")), + "Assertion on 'as.vector(x)' failed: Must be of type 'numeric', not 'character'.", + fixed = TRUE + ) +}) + +test_that("get_type() works as expected with matrices", { + expect_equal(get_type(matrix(1:4, nrow = 2)), "integer") + expect_equal(get_type(matrix(c(1.0, 2:4))), "integer") + expect_equal(get_type(matrix(c(1.0, 2.3, 3, 4))), "continuous") + + # matrix of factors doesn't work + expect_error( + get_type(matrix(factor(1:4), nrow = 2)), + "Assertion on 'as.vector(x)' failed: Must be of type 'numeric', not 'character'.", + fixed = TRUE + ) + + expect_error( + get_type(matrix(c("a", "b", "c", "d"))), + "Assertion on 'as.vector(x)' failed: Must be of type 'numeric', not 'character'.", + fixed = TRUE + ) +}) + + +test_that("new `get_type()` is equal to old `prediction_type()", { + 
get_prediction_type <- function(data) { + if (is.data.frame(data)) { + data <- data$predicted + } + if ( + isTRUE(all.equal(as.vector(data), as.integer(data))) && + !all(is.na(as.integer(data))) + ) { + return("integer") + } else if (suppressWarnings(!all(is.na(as.numeric(data))))) { + return("continuous") + } else { + stop("Input is not numeric and cannot be coerced to numeric") + } + } + + check_data <- list( + 1:2, + # factor(1:2) # old function would classify as "continuous" + c(1.0, 2), + c(1.0, 2.3), + matrix(1:4, nrow = 2), + matrix(c(1.0, 2:4)), + matrix(c(1.0, 2.3, 3, 4)) + ) + + for (i in seq_along(check_data)) { + expect_equal( + get_prediction_type(check_data[[i]]), + get_type(check_data[[i]]) + ) + } +}) + +test_that("get_type() handles `NA` values", { + expect_equal(get_type(c(1, NA, 3)), "integer") + expect_equal(get_type(c(1, NA, 3.2)), "continuous") + expect_error(get_type(NA), "Can't get type: all values of are \"NA\"") +}) diff --git a/tests/testthat/test-get-pit.R b/tests/testthat/test-get-pit.R new file mode 100644 index 000000000..86aea7b5c --- /dev/null +++ b/tests/testthat/test-get-pit.R @@ -0,0 +1,40 @@ +# ============================================================================== +# plot_pit() +# ============================================================================== +test_that("plot_pit() works as expected with quantile forecasts", { + pit <- example_quantile %>% + na.omit() %>% + as_forecast_quantile() %>% + get_pit(by = "model") + p <- plot_pit(pit, breaks = seq(0.1, 1, 0.1)) + expect_s3_class(p, "ggplot") + skip_on_cran() + vdiffr::expect_doppelganger("plot_pit_quantile", p) + + p2 <- plot_pit(pit) + expect_s3_class(p2, "ggplot") + skip_on_cran() + vdiffr::expect_doppelganger("plot_pit_quantile_2", p2) +}) + +test_that("plot_pit() works as expected with integer forecasts", { + set.seed(587) + pit <- example_sample_discrete %>% + na.omit() %>% + as_forecast_sample() %>% + get_pit(by = "model") + p <- plot_pit(pit) + 
expect_s3_class(p, "ggplot") + skip_on_cran() + vdiffr::expect_doppelganger("plot_pit_integer", p) +}) + +test_that("plot_pit() works as expected with sample forecasts", { + observed <- rnorm(30, mean = 1:30) + predicted <- replicate(200, rnorm(n = 30, mean = 1:30)) + pit <- pit_sample(observed, predicted) + p <- plot_pit(pit) + expect_s3_class(p, "ggplot") + skip_on_cran() + vdiffr::expect_doppelganger("plot_pit_sample", p) +}) diff --git a/tests/testthat/test-utils.R b/tests/testthat/test-get-protected-columns.R similarity index 54% rename from tests/testthat/test-utils.R rename to tests/testthat/test-get-protected-columns.R index ca88ef4da..cd933bf3c 100644 --- a/tests/testthat/test-utils.R +++ b/tests/testthat/test-get-protected-columns.R @@ -1,5 +1,20 @@ -test_that("get_protected_columns() returns the correct result", { +# ============================================================================== +# `get_protected_columns()` +# ============================================================================== +test_that("get_protected_columns() works as expected", { + expect_equal( + scoringutils:::get_protected_columns(), + c( + "predicted", "observed", "sample_id", + "quantile_level", "upper", "lower", "pit_value", + "interval_range", "boundary", "predicted_label", "interval_coverage", + "interval_coverage_deviation", "quantile_coverage", + "quantile_coverage_deviation" + ) + ) +}) +test_that("get_protected_columns() returns the correct result", { data <- example_quantile manual <- protected_columns <- c( "predicted", "observed", "sample_id", "quantile_level", "upper", "lower", @@ -34,47 +49,3 @@ test_that("get_protected_columns() returns the correct result", { auto <- get_protected_columns(data) expect_equal(sort(manual), sort(auto)) }) - - -test_that("run_safely() works as expected", { - f <- function(x) {x} - expect_equal(run_safely(2, fun = f), 2) - expect_equal(run_safely(2, y = 3, fun = f), 2) - expect_warning( - run_safely(fun = f, metric_name = "f"), 
- 'Computation for `f` failed. Error: argument "x" is missing, with no default', - fixed = TRUE - ) - expect_equal(suppressWarnings(run_safely(y = 3, fun = f, metric_name = "f")), NULL) -}) - - -# ============================================================================== -# get metrics -# ============================================================================== - -test_that("get_metrics.scores() works as expected", { - expect_true( - "brier_score" %in% get_metrics.scores(scores_binary) - ) - - expect_equal(get_metrics.scores(scores_sample_continuous), - attr(scores_sample_continuous, "metrics")) - - #check that function errors if `error = TRUE` and not otherwise - expect_error( - get_metrics.scores(example_quantile, error = TRUE), - "Input needs an attribute" - ) - expect_no_condition( - get_metrics.scores(scores_sample_continuous) - ) - - # expect warning if some column changed - ex <- data.table::copy(scores_sample_continuous) - data.table::setnames(ex, old = "crps", new = "changed") - expect_warning( - get_metrics.scores(ex), - "scores have been previously computed, but are no longer column names" - ) -}) diff --git a/tests/testthat/test-get_-functions.R b/tests/testthat/test-get_-functions.R deleted file mode 100644 index 4b12ab5ed..000000000 --- a/tests/testthat/test-get_-functions.R +++ /dev/null @@ -1,357 +0,0 @@ -# ============================================================================== -# `get_forecast_type` -# ============================================================================== -test_that("get_forecast_type() works as expected", { - expect_equal(get_forecast_type(example_quantile), "quantile") - expect_equal(get_forecast_type(example_sample_continuous), "sample") - expect_equal(get_forecast_type(example_sample_discrete), "sample") - expect_equal(get_forecast_type(example_binary), "binary") - expect_equal(get_forecast_type(example_point), "point") - expect_equal(get_forecast_type(example_nominal), "nominal") - - expect_error( - 
get_forecast_type(data.frame(x = 1:10)), - "Input is not a valid forecast object", - fixed = TRUE - ) - - test <- test <- data.table::copy(example_quantile) - class(test) <- c("forecast", "data.table", "data.frame") - expect_error( - get_forecast_type(test), - "Input is not a valid forecast object", - ) -}) - - -# ============================================================================== -# get_metrics() -# ============================================================================== -test_that("get_metrics() works as expected", { - expect_equal( - get_metrics.scores(scores_point), - c("ae_point", "se_point", "ape") - ) - - expect_null( - get_metrics.scores(as.data.frame(as.matrix(scores_point))) - ) -}) - - -# ============================================================================== -# `get_forecast_unit()` -# ============================================================================== -test_that("get_forecast_unit() works as expected", { - fc <- c( - "location", "target_end_date", "target_type", "location_name", - "forecast_date", "model", "horizon" - ) - - expect_equal(get_forecast_unit(example_quantile), fc) - expect_equal(get_forecast_unit(scores_quantile), fc) - - # test with data.frame - expect_equal(get_forecast_unit(as.data.frame(example_quantile)), fc) -}) - - -# ============================================================================== -# Test removing `NA` values from the data -# ============================================================================== -test_that("removing NA rows from data works as expected", { - expect_equal(nrow(na.omit(example_quantile)), 20401) - - ex <- data.frame(observed = c(NA, 1:3), predicted = 1:4) - expect_equal(nrow(na.omit(ex)), 3) - - ex$predicted <- c(1:3, NA) - expect_equal(nrow(na.omit(ex)), 2) - - # test that attributes and classes are retained - ex <- as_forecast_sample(na.omit(example_sample_discrete)) - expect_s3_class( - na.omit(ex), - c("forecast_sample", "forecast", "data.table", 
"data.frame"), - exact = TRUE - ) - - attributes <- attributes(ex) - expect_equal( - attributes(na.omit(ex)), - attributes - ) -}) - - -# ============================================================================== -# `get_type()` -# ============================================================================== -test_that("get_type() works as expected with vectors", { - expect_equal(get_type(1:3), "integer") - expect_equal(get_type(factor(1:2)), "classification") - expect_equal(get_type(c(1.0, 2)), "integer") - expect_equal(get_type(c(1.0, 2.3)), "continuous") - expect_error( - get_type(c("a", "b")), - "Assertion on 'as.vector(x)' failed: Must be of type 'numeric', not 'character'.", - fixed = TRUE - ) -}) - -test_that("get_type() works as expected with matrices", { - expect_equal(get_type(matrix(1:4, nrow = 2)), "integer") - expect_equal(get_type(matrix(c(1.0, 2:4))), "integer") - expect_equal(get_type(matrix(c(1.0, 2.3, 3, 4))), "continuous") - - # matrix of factors doesn't work - expect_error( - get_type(matrix(factor(1:4), nrow = 2)), - "Assertion on 'as.vector(x)' failed: Must be of type 'numeric', not 'character'.", - fixed = TRUE - ) - - expect_error( - get_type(matrix(c("a", "b", "c", "d"))), - "Assertion on 'as.vector(x)' failed: Must be of type 'numeric', not 'character'.", - fixed = TRUE - ) -}) - - -test_that("new `get_type()` is equal to old `prediction_type()", { - get_prediction_type <- function(data) { - if (is.data.frame(data)) { - data <- data$predicted - } - if ( - isTRUE(all.equal(as.vector(data), as.integer(data))) && - !all(is.na(as.integer(data))) - ) { - return("integer") - } else if (suppressWarnings(!all(is.na(as.numeric(data))))) { - return("continuous") - } else { - stop("Input is not numeric and cannot be coerced to numeric") - } - } - - check_data <- list( - 1:2, - # factor(1:2) # old function would classify as "continuous" - c(1.0, 2), - c(1.0, 2.3), - matrix(1:4, nrow = 2), - matrix(c(1.0, 2:4)), - matrix(c(1.0, 2.3, 3, 4)) - ) - - 
for (i in seq_along(check_data)) { - expect_equal( - get_prediction_type(check_data[[i]]), - get_type(check_data[[i]]) - ) - } -}) - -test_that("get_type() handles `NA` values", { - expect_equal(get_type(c(1, NA, 3)), "integer") - expect_equal(get_type(c(1, NA, 3.2)), "continuous") - expect_error(get_type(NA), "Can't get type: all values of are \"NA\"") -}) - - -# ============================================================================== -# get_duplicate_forecasts() -# ============================================================================== -test_that("get_duplicate_forecasts() works as expected for quantile", { - expect_no_condition(get_duplicate_forecasts( - example_quantile, - forecast_unit = - c("location", "target_end_date", "target_type", "location_name", - "forecast_date", "model") - ) - ) - - expect_equal(nrow(get_duplicate_forecasts(example_quantile)), 0) - expect_equal( - nrow( - get_duplicate_forecasts(rbind(example_quantile, example_quantile[1000:1010]))), - 22 - ) -}) - -test_that("get_duplicate_forecasts() works as expected for sample", { - expect_equal(nrow(get_duplicate_forecasts(example_sample_continuous)), 0) - expect_equal( - nrow( - get_duplicate_forecasts(rbind(example_sample_continuous, example_sample_continuous[1040:1050]))), - 22 - ) -}) - - -test_that("get_duplicate_forecasts() works as expected for binary", { - expect_equal(nrow(get_duplicate_forecasts(example_binary)), 0) - expect_equal( - nrow( - get_duplicate_forecasts(rbind(example_binary, example_binary[1000:1010]))), - 22 - ) -}) - -test_that("get_duplicate_forecasts() works as expected for point", { - expect_equal(nrow(get_duplicate_forecasts(example_binary)), 0) - expect_equal( - nrow( - get_duplicate_forecasts(rbind(example_point, example_point[1010:1020]))), - 22 - ) - - expect_s3_class( - get_duplicate_forecasts(as.data.frame(example_point)), - c("data.table", "data.frame"), - exact = TRUE - ) -}) - -test_that("get_duplicate_forecasts() returns the expected class", { - 
expect_equal( - class(get_duplicate_forecasts(example_point)), - c("data.table", "data.frame") - ) -}) - -test_that("get_duplicate_forecasts() works as expected with a data.frame", { - duplicates <- get_duplicate_forecasts( - rbind(example_quantile_df, example_quantile_df[101:110, ]) - ) - expect_equal(nrow(duplicates), 20) -}) - -test_that("get_duplicate_forecasts() shows counts correctly", { - duplicates <- get_duplicate_forecasts( - rbind(example_quantile, example_quantile[101:110, ]), - counts = TRUE - ) - expect_equal(nrow(duplicates), 2) - expect_equal(unique(duplicates$n_duplicates), 10) -}) - - -# ============================================================================== -# `get_coverage()` -# ============================================================================== -ex_coverage <- example_quantile[model == "EuroCOVIDhub-ensemble"] - -test_that("get_coverage() works as expected", { - cov <- example_quantile %>% - get_coverage(by = get_forecast_unit(example_quantile)) - - expect_equal( - sort(colnames(cov)), - sort(c(get_forecast_unit(example_quantile), c( - "interval_range", "quantile_level", "interval_coverage", "interval_coverage_deviation", - "quantile_coverage", "quantile_coverage_deviation" - ))) - ) - - expect_equal(nrow(cov), nrow(na.omit(example_quantile))) - - expect_s3_class( - cov, - c("data.table", "data.frame"), - exact = TRUE - ) -}) - -test_that("get_coverage() outputs an object of class c('data.table', 'data.frame'", { - cov <- get_coverage(example_quantile) - expect_s3_class(cov, c("data.table", "data.frame"), exact = TRUE) -}) - -test_that("get_coverage() can deal with non-symmetric prediction intervals", { - # the expected result is that `get_coverage()` just works. 
However, - # all interval coverages with missing values should just be `NA` - test <- data.table::copy(example_quantile) - test <- test[!quantile_level %in% c(0.2, 0.3, 0.5)] - - expect_no_condition(cov <- get_coverage(test)) - - prediction_intervals <- get_range_from_quantile(c(0.2, 0.3, 0.5)) - - missing <- cov[interval_range %in% prediction_intervals] - not_missing <- cov[!interval_range %in% prediction_intervals] - - expect_true(all(is.na(missing$interval_coverage))) - expect_false(any(is.na(not_missing))) - - # test for a version where values are not missing, but just `NA` - # since `get_coverage()` calls `na.omit`, the result should be the same. - test <- data.table::copy(example_quantile) - test <- test[quantile_level %in% c(0.2, 0.3, 0.5), predicted := NA] - cov2 <- get_coverage(test) - expect_equal(cov, cov2) -}) - - -# ============================================================================== -# `get_protected_columns()` -# ============================================================================== -test_that("get_protected_columns() works as expected", { - expect_equal( - scoringutils:::get_protected_columns(), - c("predicted", "observed", "sample_id", - "quantile_level", "upper", "lower", "pit_value", - "interval_range", "boundary", "predicted_label", "interval_coverage", - "interval_coverage_deviation", "quantile_coverage", - "quantile_coverage_deviation") - ) -}) - -# ============================================================================== -# `get_forecast_counts()` -# ============================================================================== -test_that("get_forecast_counts() works as expected", { - af <- data.table::copy(example_quantile) - af <- get_forecast_counts( - af, - by = c("model", "target_type", "target_end_date") - ) - - expect_type(af, "list") - expect_type(af$target_type, "character") - expect_type(af$`count`, "integer") - expect_equal(nrow(af[is.na(`count`)]), 0) - af <- example_quantile %>% - get_forecast_counts(by = 
"model") - expect_equal(nrow(af), 4) - expect_equal(af$`count`, c(256, 256, 128, 247)) - - # Ensure the returning object class is exactly same as a data.table. - expect_s3_class(af, c("data.table", "data.frame"), exact = TRUE) - - # Setting `collapse = c()` means that all quantiles and samples are counted - af <- example_quantile %>% - get_forecast_counts(by = "model", collapse = c()) - expect_equal(nrow(af), 4) - expect_equal(af$`count`, c(5888, 5888, 2944, 5681)) - - af <- example_quantile %>% - get_forecast_counts() - expect_equal(nrow(af), 50688) - - expect_error( - get_forecast_counts(example_quantile, by = NULL), - "Assertion on 'by' failed: Must be a subset of" - ) - - # check whether collapsing also works for model-based forecasts - af <- example_sample_discrete %>% - get_forecast_counts(by = "model") - expect_equal(nrow(af), 4) - - af <- example_sample_discrete %>% - get_forecast_counts(by = "model", collapse = c()) - expect_equal(af$count, c(10240, 10240, 5120, 9880)) -}) diff --git a/tests/testthat/test-get_metrics.R b/tests/testthat/test-get_metrics.R deleted file mode 100644 index 0519ecba6..000000000 --- a/tests/testthat/test-get_metrics.R +++ /dev/null @@ -1 +0,0 @@ - \ No newline at end of file diff --git a/tests/testthat/test-utils_data_handling.R b/tests/testthat/test-helper-quantile-interval-range.R similarity index 78% rename from tests/testthat/test-utils_data_handling.R rename to tests/testthat/test-helper-quantile-interval-range.R index 089f40207..3f94f0468 100644 --- a/tests/testthat/test-utils_data_handling.R +++ b/tests/testthat/test-helper-quantile-interval-range.R @@ -1,3 +1,7 @@ +# ============================================================================== +# quantile_to_interval() +# ============================================================================== + test_that("quantile_to_interval_dataframe() works", { quantile <- data.frame( date = as.Date("2020-01-01") + 1:10, @@ -35,7 +39,7 @@ 
test_that("quantile_to_interval_dataframe() works", { # - after, it's a 'warning' # - the conditionMessage() also differs expected_condition <- tryCatch( - dcast(data.table(a = c(1, 1), b = 2, c = 3), a ~ b, value.var="c"), + dcast(data.table(a = c(1, 1), b = 2, c = 3), a ~ b, value.var = "c"), condition = identity ) expect_condition( @@ -59,73 +63,6 @@ test_that("quantile_to_interval_dataframe() works", { }) -test_that("as_forecast_quantiles works", { - samples <- data.frame( - date = as.Date("2020-01-01") + 1:10, - model = "model1", - observed = 1:10, - predicted = c(rep(0, 10), 2:11, 3:12, 4:13, rep(100, 10)), - sample_id = rep(1:5, each = 10) - ) %>% - as_forecast_sample() - - quantile <- data.frame( - date = rep(as.Date("2020-01-01") + 1:10, each = 2), - model = "model1", - observed = rep(1:10, each = 2), - quantile_level = c(0.25, 0.75), - predicted = rep(2:11, each = 2) + c(0, 2) - ) - - expect_no_condition( - as_forecast_quantile(samples, probs = c(0.25, 0.75)) - ) - - wrongclass <- as_forecast_sample(samples) - class(wrongclass) <- c("forecast_point", "data.table", "data.frame") - expect_error( - as_forecast_quantile(wrongclass, quantile_level = c(0.25, 0.75)), - "Assertion on 'quantile_level' failed: Must be of type" - ) - - - quantile2 <- as_forecast_quantile( - as_forecast_sample(samples), - probs = c(0.25, 0.75) - ) - - expect_equal(quantile, as.data.frame(quantile2)) - - # Verify that `type` is correctly scoped in as_forecast_quantile(), as it is - # also an argument. - # If it's not scoped well, the call to `as_forecast_quantile()` will fail. 
- samples$type <- "test" - - quantile3 <- as_forecast_quantile( - as_forecast_sample(samples), - probs = c(0.25, 0.75) - ) - quantile3$type <- NULL - - expect_identical( - quantile2, - quantile3 - ) -}) - -test_that("as_forecast_quantiles issue 557 fix", { - - out <- example_sample_discrete %>% - na.omit %>% - as_forecast_quantile( - probs = c(0.01, 0.025, seq(0.05, 0.95, 0.05), 0.975, 0.99) - ) %>% - score() - - expect_equal(any(is.na(out$interval_coverage_deviation)), FALSE) -}) - - test_that("sample_to_range_long works", { samples <- data.frame( date = as.Date("2020-01-01") + 1:10, diff --git a/tests/testthat/test-input-check-helpers.R b/tests/testthat/test-input-check-helpers.R deleted file mode 100644 index 1c00215cf..000000000 --- a/tests/testthat/test-input-check-helpers.R +++ /dev/null @@ -1,87 +0,0 @@ -test_that("Check equal length works if all arguments have length 1", { - out <- interval_score( - observed = 5, - lower = 4, - upper = 6, - interval_range = 95, - weigh = TRUE, - separate_results = FALSE - ) - expect_equal(out, 0.05) -}) - -test_that("check_number_per_forecast works", { - expect_identical( - capture.output( - check_number_per_forecast( - example_binary, forecast_unit = "location_name" - ) - ), - paste( - "[1] \"Some forecasts have different numbers of rows", - "(e.g. quantiles or samples). scoringutils found: 224, 215.", - "This may be a problem (it can potentially distort scores,", - "making it more difficult to compare them),", - "so make sure this is intended.\"" - ) - ) - expect_true( - check_number_per_forecast( - example_binary - ) - ) -}) - - -test_that("check_duplicates works", { - example_bin <- rbind(example_binary[1000:1002, ], example_binary[1000:1002, ]) - expect_identical( - capture.output( - check_duplicates(example_bin) - ), - paste( - "[1] \"There are instances with more than one forecast for the same", - "target. 
This can't be right and needs to be resolved.", - "Maybe you need to check the unit of a single forecast and add", - "missing columns? Use the function get_duplicate_forecasts() to", - "identify duplicate rows\"" - ) - ) - expect_true( - check_duplicates(example_binary) - ) -}) - -test_that("check_columns_present works", { - expect_identical( - capture.output( - check_columns_present(example_binary, c("loc1", "loc2", "loc3")) - ), - paste( - "[1] \"Columns 'loc1', 'loc2', 'loc3' not found in data\"" - ) - ) - expect_identical( - capture.output( - check_columns_present(example_binary, c("loc1")) - ), - paste( - "[1] \"Column 'loc1' not found in data\"" - ) - ) - expect_true( - check_columns_present(example_binary, c("location_name")) - ) - expect_true( - check_columns_present(example_binary, columns = NULL) - ) -}) - -test_that("test_columns_not_present works", { - expect_true( - test_columns_not_present(example_binary, "sample_id") - ) - expect_false( - test_columns_not_present(example_binary, "location") - ) -}) diff --git a/tests/testthat/test-inputs-scoring-functions.R b/tests/testthat/test-inputs-scoring-functions.R index e9002ebfe..84638e6cb 100644 --- a/tests/testthat/test-inputs-scoring-functions.R +++ b/tests/testthat/test-inputs-scoring-functions.R @@ -1,39 +1,9 @@ -observed <- rnorm(30, mean = 1:30) -interval_range <- rep(90, 30) -alpha <- (100 - interval_range) / 100 -lower <- qnorm(alpha / 2, rnorm(30, mean = 1:30)) -upper <- qnorm((1 - alpha / 2), rnorm(30, mean = 11:40)) -test_that("assert_input_interval() works as expected", { - expect_no_condition( - assert_input_interval(observed, lower, upper, interval_range) - ) - - # expect error if upper < lower - expect_error( - assert_input_interval(observed, upper, lower, interval_range), - "All values in `upper` need to be greater than or equal to the corresponding values in `lower`" - ) - # expect warning if interval range is < 1 - expect_warning( - assert_input_interval(observed, lower, upper, 0.5), - 
"Found interval ranges between 0 and 1. Are you sure that's right?" - ) -}) - - -test_that("check_input_interval() works as expected", { - expect_no_condition( - check_input_interval(observed, lower, upper, interval_range) - ) - # expect message return if upper < lower - expect_match( - check_input_interval(observed, upper, lower, interval_range), - regexp = "All values in `upper` need to be greater than or equal" - ) -}) +# ============================================================================== +# assert_dims_ok_point() +# ============================================================================== test_that("assert_dims_ok_point() works as expected", { # expect no error if dimensions are ok @@ -53,6 +23,10 @@ test_that("assert_dims_ok_point() works as expected", { }) +# ============================================================================== +# check_input_sample() +# ============================================================================== + test_that("check_input_sample() works as expected", { # expect no error if dimensions are ok expect_true(check_input_sample(1:10, matrix(1:20, nrow = 10))) @@ -62,23 +36,4 @@ test_that("check_input_sample() works as expected", { check_input_sample(1:10, 1:11), "Assertion on 'predicted' failed: Must be of type 'matrix', not 'integer'." ) -}) - -test_that("check_input_quantile() works as expected", { - # expect no error if dimensions are ok - expect_true( - check_input_quantile( - 1:10, matrix(1:20, nrow = 10), - quantile_level = c(0.1, 0.9) - ) - ) - - # expect error if dimensions are not ok - expect_match( - check_input_quantile( - 1:10, matrix(1:20, nrow = 10), - quantile_level = seq(0.1, 0.9, length.out = 8) - ), - "Assertion on 'predicted' failed: Must have exactly 8 cols, but has 2 cols." 
- ) -}) +}) \ No newline at end of file diff --git a/tests/testthat/test-metrics-interval-range.R b/tests/testthat/test-metrics-interval-range.R new file mode 100644 index 000000000..aeed859b7 --- /dev/null +++ b/tests/testthat/test-metrics-interval-range.R @@ -0,0 +1,39 @@ +observed <- rnorm(30, mean = 1:30) +interval_range <- rep(90, 30) +alpha <- (100 - interval_range) / 100 +lower <- qnorm(alpha / 2, rnorm(30, mean = 1:30)) +upper <- qnorm((1 - alpha / 2), rnorm(30, mean = 11:40)) + + +# ============================================================================== +# assert_input_interval() +# ============================================================================== +test_that("assert_input_interval() works as expected", { + expect_no_condition( + assert_input_interval(observed, lower, upper, interval_range) + ) + + # expect error if upper < lower + expect_error( + assert_input_interval(observed, upper, lower, interval_range), + "All values in `upper` need to be greater than or equal to the corresponding values in `lower`" + ) + + # expect warning if interval range is < 1 + expect_warning( + assert_input_interval(observed, lower, upper, 0.5), + "Found interval ranges between 0 and 1. Are you sure that's right?" 
+ ) +}) + + +test_that("check_input_interval() works as expected", { + expect_no_condition( + check_input_interval(observed, lower, upper, interval_range) + ) + # expect message return if upper < lower + expect_match( + check_input_interval(observed, upper, lower, interval_range), + regexp = "All values in `upper` need to be greater than or equal" + ) +}) diff --git a/tests/testthat/test-metrics-quantile.R b/tests/testthat/test-metrics-quantile.R index 0fbf8a5dc..81039d5d8 100644 --- a/tests/testthat/test-metrics-quantile.R +++ b/tests/testthat/test-metrics-quantile.R @@ -16,6 +16,28 @@ forecast_quantiles_matrix <- rbind( forecast_quantile_probs <- c(0.1, 0.25, 0.5, 0.75, 0.9) +# ============================================================================== +# check_input_quantile() +# ============================================================================== +test_that("check_input_quantile() works as expected", { + # expect no error if dimensions are ok + expect_true( + check_input_quantile( + 1:10, matrix(1:20, nrow = 10), + quantile_level = c(0.1, 0.9) + ) + ) + + # expect error if dimensions are not ok + expect_match( + check_input_quantile( + 1:10, matrix(1:20, nrow = 10), + quantile_level = seq(0.1, 0.9, length.out = 8) + ), + "Assertion on 'predicted' failed: Must have exactly 8 cols, but has 2 cols." 
+ ) +}) + # ============================================================================ # # Input handling =============================================================== # ============================================================================ # diff --git a/tests/testthat/test-metrics-sample.R b/tests/testthat/test-metrics-sample.R index 323fec21e..4476ac1bb 100644 --- a/tests/testthat/test-metrics-sample.R +++ b/tests/testthat/test-metrics-sample.R @@ -214,3 +214,84 @@ test_that("function throws an error when missing 'predicted'", { ) }) + +# ============================================================================ # +# pit_sample() +# ============================================================================ # + +test_that("pit_sample() function throws an error when missing args", { + observed <- rpois(10, lambda = 1:10) + predicted <- replicate(50, rpois(n = 10, lambda = 1:10)) + + expect_error( + pit_sample(predicted = predicted), + 'argument "observed" is missing, with no default' + ) + + expect_error( + pit_sample(observed = observed), + 'argument "predicted" is missing, with no default' + ) +}) + +test_that("pit_sample() function works for integer observed and predicted", { + observed <- rpois(10, lambda = 1:10) + predicted <- replicate(10, rpois(10, lambda = 1:10)) + output <- pit_sample( + observed = observed, + predicted = predicted, + n_replicates = 56 + ) + expect_equal( + length(output), + 560 + ) + + checkmate::expect_class(output, "numeric") +}) + +test_that("pit_sample() function works for continuous observed and predicted", { + observed <- rnorm(10) + predicted <- replicate(10, rnorm(10)) + output <- pit_sample( + observed = observed, + predicted = predicted, + n_replicates = 56 + ) + expect_equal( + length(output), + 10 + ) +}) + +test_that("pit_sample() works with a single observvation", { + expect_no_condition( + output <- pit_sample(observed = 2.5, predicted = 1.5:10.5) + ) + expect_equal(length(output), 1) + + # test discrete case 
+ expect_no_condition( + output2 <- pit_sample( + observed = 3, predicted = 1:10, n_replicates = 24 + ) + ) + expect_equal(length(output2), 24) +}) + +test_that("pit_sample() throws an error if inputs are wrong", { + observed <- 1.5:20.5 + predicted <- replicate(100, 1.5:20.5) + + # expect an error if predicted cannot be coerced to a matrix + expect_error( + pit_sample(observed, function(x) {}), + "Assertion on 'predicted' failed: Must be of type 'matrix'" + ) + + # expect an error if the number of rows in predicted does not match the length of observed + expect_error( + pit_sample(observed, predicted[1:10, ]), + "Assertion on 'predicted' failed: Must have exactly 20 rows, but has 10 rows." + ) +}) diff --git a/tests/testthat/test-metrics-validate.R b/tests/testthat/test-metrics-validate.R deleted file mode 100644 index 0272579de..000000000 --- a/tests/testthat/test-metrics-validate.R +++ /dev/null @@ -1,18 +0,0 @@ -test_that("validate_metrics() works as expected", { - test_fun <- function(x, y, ...) 
{ - if (hasArg("test")) { - message("test argument found") - } - return(y) - } - ## Additional tests for validate_metrics() - # passing in something that's not a function or a known metric - expect_warning( - expect_warning( - score(as_forecast_binary(na.omit(example_binary)), metrics = list( - "test1" = test_fun, "test" = test_fun, "hi" = "hi", "2" = 3) - ), - "`Metrics` element number 3 is not a valid function" - ), - "`Metrics` element number 4 is not a valid function") -}) diff --git a/tests/testthat/test-default-scoring-rules.R b/tests/testthat/test-metrics.R similarity index 92% rename from tests/testthat/test-default-scoring-rules.R rename to tests/testthat/test-metrics.R index 418848cba..cbfb64cc1 100644 --- a/tests/testthat/test-default-scoring-rules.R +++ b/tests/testthat/test-metrics.R @@ -3,7 +3,6 @@ # ============================================================================== test_that("`select_metrics` works as expected", { - expect_equal( scoringutils:::select_metrics(get_metrics(example_point), select = NULL), get_metrics(example_point) @@ -43,6 +42,41 @@ test_that("`select_metrics` works as expected", { }) +# ============================================================================== +# get_metrics() +# ============================================================================== +# See additional tests for individual classes. 
+test_that("selecting metrics in get_metrics() works as expected", { + expect_equal( + names(get_metrics(example_point, select = "ape")), + "ape" + ) + + expect_equal( + length(get_metrics(example_binary, select = NULL, exclude = "brier_score")), + length(get_metrics(example_binary)) - 1 + ) + + # if both select and exclude are specified, exclude is ignored + expect_equal( + names(scoringutils:::select_metrics(get_metrics(example_quantile), select = "wis", exclude = "wis")), + "wis" + ) + + # expect error if select is not included in the default possibilities + expect_error( + get_metrics(example_sample_continuous, select = "not-included"), + "Must be a subset of" + ) + + # expect error if exclude is not included in the default possibilities + expect_error( + get_metrics(example_quantile, exclude = "not-included"), + "Must be a subset of" + ) +}) + + # ============================================================================== # Customising metrics using purrr::partial() # ============================================================================== @@ -85,48 +119,3 @@ test_that("purrr::partial() has the expected output class", { custom_metric <- purrr::partial(mean, na.rm = TRUE) checkmate::expect_class(custom_metric, "function") }) - - -# ============================================================================== -# default scoring rules -# ============================================================================== - -test_that("default rules work as expected", { - - expect_true( - all(c( - is.list(get_metrics(example_quantile)), - is.list(get_metrics(example_binary)), - is.list(get_metrics(example_sample_continuous)), - is.list(get_metrics(example_point))) - ) - ) - - expect_equal( - names(get_metrics(example_point, select = "ape")), - "ape" - ) - - expect_equal( - length(get_metrics(example_binary, select = NULL, exclude = "brier_score")), - length(get_metrics(example_binary)) - 1 - ) - - # if both select and exclude are specified, exclude is ignored 
- expect_equal( - names(scoringutils:::select_metrics(get_metrics(example_quantile), select = "wis", exclude = "wis")), - "wis" - ) - - # expect error if select is not included in the default possibilities - expect_error( - get_metrics(example_sample_continuous, select = "not-included"), - "Must be a subset of" - ) - - # expect error if exclude is not included in the default possibilities - expect_error( - get_metrics(example_quantile, exclude = "not-included"), - "Must be a subset of" - ) -}) diff --git a/tests/testthat/test-pairwise_comparison.R b/tests/testthat/test-pairwise_comparison.R index 9228b88e7..dac73812a 100644 --- a/tests/testthat/test-pairwise_comparison.R +++ b/tests/testthat/test-pairwise_comparison.R @@ -523,3 +523,25 @@ test_that("permutation_tests work as expected", { ) ) }) + + +# ============================================================================== +# plot_pairwise_comparison() +# ============================================================================== +pairwise <- get_pairwise_comparisons(scores_quantile, by = "target_type") + +test_that("plot_pairwise_comparisons() works as expected", { + p <- plot_pairwise_comparisons(pairwise) + + ggplot2::facet_wrap(~target_type) + expect_s3_class(p, "ggplot") + skip_on_cran() + vdiffr::expect_doppelganger("plot_pairwise_comparison", p) +}) + +test_that("plot_pairwise_comparisons() works when showing p values", { + p <- plot_pairwise_comparisons(pairwise, type = "pval") + + ggplot2::facet_wrap(~target_type) + expect_s3_class(p, "ggplot") + skip_on_cran() + vdiffr::expect_doppelganger("plot_pairwise_comparison_pval", p) +}) diff --git a/tests/testthat/test-pit.R b/tests/testthat/test-pit.R deleted file mode 100644 index 2bb1fc885..000000000 --- a/tests/testthat/test-pit.R +++ /dev/null @@ -1,111 +0,0 @@ -# ============================================================================ # -# pit_sample() -# ============================================================================ # - 
-test_that("pit_sample() function throws an error when missing args", { - observed <- rpois(10, lambda = 1:10) - predicted <- replicate(50, rpois(n = 10, lambda = 1:10)) - - expect_error( - pit_sample(predicted = predicted), - 'argument "observed" is missing, with no default' - ) - - expect_error( - pit_sample(observed = observed), - 'argument "predicted" is missing, with no default' - ) -}) - -test_that("pit_sample() function works for integer observed and predicted", { - observed <- rpois(10, lambda = 1:10) - predicted <- replicate(10, rpois(10, lambda = 1:10)) - output <- pit_sample( - observed = observed, - predicted = predicted, - n_replicates = 56 - ) - expect_equal( - length(output), - 560 - ) - - checkmate::expect_class(output, "numeric") -}) - -test_that("pit_sample() function works for continuous observed and predicted", { - observed <- rnorm(10) - predicted <- replicate(10, rnorm(10)) - output <- pit_sample( - observed = observed, - predicted = predicted, - n_replicates = 56 - ) - expect_equal( - length(output), - 10 - ) -}) - -test_that("pit_sample() works with a single observvation", { - expect_no_condition( - output <- pit_sample(observed = 2.5, predicted = 1.5:10.5) - ) - expect_equal(length(output), 1) - - # test discrete case - expect_no_condition( - output2 <- pit_sample( - observed = 3, predicted = 1:10, n_replicates = 24 - ) - ) - expect_equal(length(output2), 24) -}) - - -# ============================================================================ # -# get_pit() -# ============================================================================ # - -test_that("pit function works for continuous integer and quantile data", { - pit_quantile <- suppressMessages(as_forecast_quantile(example_quantile)) %>% - get_pit(by = "model") - pit_continuous <- suppressMessages(as_forecast_sample(example_sample_continuous)) %>% - get_pit(by = c("model", "target_type")) - pit_integer <- suppressMessages(as_forecast_sample(example_sample_discrete)) %>% - get_pit(by 
= c("model", "location")) - - expect_equal(names(pit_quantile), c("model", "quantile_level", "pit_value")) - expect_equal(names(pit_continuous), c("model", "target_type", "pit_value")) - expect_equal(names(pit_integer), c("model", "location", "pit_value")) - - expect_s3_class(pit_quantile, c("data.table", "data.frame"), exact = TRUE) - - # check printing works - testthat::expect_output(print(pit_quantile)) - testthat::expect_output(print(pit_continuous)) - testthat::expect_output(print(pit_integer)) - - # check class is correct - expect_s3_class(pit_quantile, c("data.table", "data.frame"), exact = TRUE) - expect_s3_class(pit_continuous, c("data.table", "data.frame"), exact = TRUE) - expect_s3_class(pit_integer, c("data.table", "data.frame"), exact = TRUE) -}) - -test_that("pit_sample() throws an error if inputs are wrong", { - observed <- 1.5:20.5 - predicted <- replicate(100, 1.5:20.5) - - # expect an error if predicted cannot be coerced to a matrix - expect_error( - pit_sample(observed, function(x) {}), - "Assertion on 'predicted' failed: Must be of type 'matrix'" - ) - - # expect an error if the number of rows in predicted does not match the length of observed - expect_error( - pit_sample(observed, predicted[1:10, ]), - "Assertion on 'predicted' failed: Must have exactly 20 rows, but has 10 rows." 
- ) -}) - diff --git a/tests/testthat/test-plot_avail_forecasts.R b/tests/testthat/test-plot_avail_forecasts.R deleted file mode 100644 index 39ae5e5b5..000000000 --- a/tests/testthat/test-plot_avail_forecasts.R +++ /dev/null @@ -1,14 +0,0 @@ -test_that("plot.forecast_counts() works as expected", { - available_forecasts <- na.omit(example_quantile) %>% - as_forecast_quantile() %>% - get_forecast_counts( - by = c("model", "target_type", "target_end_date") - ) - p <- plot_forecast_counts(available_forecasts, - x = "target_end_date", show_counts = FALSE - ) + - facet_wrap("target_type") - expect_s3_class(p, "ggplot") - skip_on_cran() - vdiffr::expect_doppelganger("plot_available_forecasts", p) -}) diff --git a/tests/testthat/test-plot_correlation.R b/tests/testthat/test-plot_correlation.R deleted file mode 100644 index 5c03329a7..000000000 --- a/tests/testthat/test-plot_correlation.R +++ /dev/null @@ -1,18 +0,0 @@ -test_that("plot_correlations() works as expected", { - correlations <- get_correlations( - summarise_scores( - scores_quantile, - by = get_forecast_unit(scores_quantile) - ) - ) - p <- plot_correlations(correlations, digits = 2) - expect_s3_class(p, "ggplot") - skip_on_cran() - vdiffr::expect_doppelganger("plot__correlation", p) - - # expect an error if you forgot to compute correlations - expect_error( - plot_correlations(summarise_scores(scores_quantile)), - "Did you forget to call `scoringutils::get_correlations()`?" 
- ) -}) diff --git a/tests/testthat/test-plot_interval_coverage.R b/tests/testthat/test-plot_interval_coverage.R deleted file mode 100644 index fbe76a519..000000000 --- a/tests/testthat/test-plot_interval_coverage.R +++ /dev/null @@ -1,15 +0,0 @@ -test_that("plot_interval_coverage() works as expected", { - coverage <- example_quantile %>% - na.omit() %>% - as_forecast_quantile() %>% - get_coverage(by = c("model")) - p <- plot_interval_coverage(coverage) - expect_s3_class(p, "ggplot") - skip_on_cran() - suppressWarnings(vdiffr::expect_doppelganger("plot_interval_coverage", p)) - - # make sure that plot_interval_coverage() doesn't drop column names - expect_true(all(c("interval_coverage", "interval_coverage_deviation", - "quantile_coverage", "quantile_coverage_deviation") %in% - names(coverage))) -}) diff --git a/tests/testthat/test-plot_pairwise_comparison.R b/tests/testthat/test-plot_pairwise_comparison.R deleted file mode 100644 index 88bac4fa5..000000000 --- a/tests/testthat/test-plot_pairwise_comparison.R +++ /dev/null @@ -1,19 +0,0 @@ -pairwise <- suppressMessages( - get_pairwise_comparisons(scores_quantile, by = "target_type") -) - -test_that("plot_pairwise_comparisons() works as expected", { - p <- plot_pairwise_comparisons(pairwise) + - ggplot2::facet_wrap(~target_type) - expect_s3_class(p, "ggplot") - skip_on_cran() - vdiffr::expect_doppelganger("plot_pairwise_comparison", p) -}) - -test_that("plot_pairwise_comparisons() works when showing p values", { - p <- plot_pairwise_comparisons(pairwise, type = "pval") + - ggplot2::facet_wrap(~target_type) - expect_s3_class(p, "ggplot") - skip_on_cran() - vdiffr::expect_doppelganger("plot_pairwise_comparison_pval", p) -}) diff --git a/tests/testthat/test-plot_pit.R b/tests/testthat/test-plot_pit.R deleted file mode 100644 index c1493b177..000000000 --- a/tests/testthat/test-plot_pit.R +++ /dev/null @@ -1,37 +0,0 @@ -test_that("plot_pit() works as expected with quantile forecasts", { - pit <- example_quantile %>% - 
na.omit() %>% - as_forecast_quantile() %>% - get_pit(by = "model") - p <- plot_pit(pit, breaks = seq(0.1, 1, 0.1)) - expect_s3_class(p, "ggplot") - skip_on_cran() - vdiffr::expect_doppelganger("plot_pit_quantile", p) - - p2 <- plot_pit(pit) - expect_s3_class(p2, "ggplot") - skip_on_cran() - vdiffr::expect_doppelganger("plot_pit_quantile_2", p2) -}) - -test_that("plot_pit() works as expected with integer forecasts", { - set.seed(587) - pit <- example_sample_discrete %>% - na.omit() %>% - as_forecast_sample() %>% - get_pit(by = "model") - p <- plot_pit(pit) - expect_s3_class(p, "ggplot") - skip_on_cran() - vdiffr::expect_doppelganger("plot_pit_integer", p) -}) - -test_that("plot_pit() works as expected with sample forecasts", { - observed <- rnorm(30, mean = 1:30) - predicted <- replicate(200, rnorm(n = 30, mean = 1:30)) - pit <- pit_sample(observed, predicted) - p <- plot_pit(pit) - expect_s3_class(p, "ggplot") - skip_on_cran() - vdiffr::expect_doppelganger("plot_pit_sample", p) -}) diff --git a/tests/testthat/test-plot_quantile_coverage.R b/tests/testthat/test-plot_quantile_coverage.R deleted file mode 100644 index 1060b0d9f..000000000 --- a/tests/testthat/test-plot_quantile_coverage.R +++ /dev/null @@ -1,11 +0,0 @@ -test_that("plot_quantile_coverage() works as expected", { - coverage <- example_quantile %>% - na.omit() %>% - as_forecast_quantile() %>% - get_coverage(by = c("model", "quantile_level")) - - p <- plot_quantile_coverage(coverage) - expect_s3_class(p, "ggplot") - skip_on_cran() - suppressWarnings(vdiffr::expect_doppelganger("plot_quantile_coverage", p)) -}) diff --git a/tests/testthat/test-print.R b/tests/testthat/test-print.R deleted file mode 100644 index f4720d224..000000000 --- a/tests/testthat/test-print.R +++ /dev/null @@ -1,46 +0,0 @@ -test_that("print() works on forecast_* objects", { - # Check print works on each forecast object - test_dat <- list(example_binary, example_quantile, - example_point, example_sample_continuous, - 
example_sample_discrete) - test_dat <- lapply(test_dat, na.omit) - for (dat in test_dat){ - forecast_type <- scoringutils:::get_forecast_type(dat) - forecast_unit <- get_forecast_unit(dat) - - fn_name <- paste0("as_forecast_", forecast_type) - fn <- get(fn_name) - dat <- suppressWarnings(suppressMessages(do.call(fn, list(dat)))) - - # Check Forecast type - expect_snapshot(print(dat)) - expect_snapshot(print(dat)) - # Check Forecast unit - expect_snapshot(print(dat)) - expect_snapshot(print(dat)) - - # Check print.data.table works. - output_original <- suppressMessages(capture.output(print(dat))) - output_test <- suppressMessages(capture.output(print(data.table(dat)))) - expect_contains(output_original, output_test) - } -}) - -test_that("print() throws the expected messages", { - test <- data.table::copy(example_point) - class(test) <- c("point", "forecast", "data.table", "data.frame") - suppressMessages( - expect_message( - capture.output(print(test)), - "Could not determine forecast type due to error in validation." - ) - ) - - class(test) <- c("forecast_point", "forecast") - suppressMessages( - expect_message( - capture.output(print(test)), - "Could not determine forecast unit." 
- ) - ) -}) diff --git a/tests/testthat/test-score.R b/tests/testthat/test-score.R index 3a2f35943..ac29d1a41 100644 --- a/tests/testthat/test-score.R +++ b/tests/testthat/test-score.R @@ -22,18 +22,18 @@ test_that("as_scores() works", { ) }) -test_that("validate_scores() works", { +test_that("assert_scores() works", { expect_error( - validate_scores(data.frame()), + assert_scores(data.frame()), "Must inherit from class 'scores'" ) }) test_that("Output of `score()` has the class `scores()`", { - expect_no_condition(validate_scores(scores_point)) - expect_no_condition(validate_scores(scores_binary)) - expect_no_condition(validate_scores(scores_sample_continuous)) - expect_no_condition(validate_scores(scores_quantile)) + expect_no_condition(assert_scores(scores_point)) + expect_no_condition(assert_scores(scores_binary)) + expect_no_condition(assert_scores(scores_sample_continuous)) + expect_no_condition(assert_scores(scores_quantile)) }) # ============================================================================= @@ -71,289 +71,6 @@ test_that("Manipulating scores objects with .[ works as expected", { }) -# test binary case ------------------------------------------------------------- -test_that("function produces output for a binary case", { - - expect_equal( - names(scores_binary), - c(get_forecast_unit(example_binary), names(get_metrics(example_binary))) - ) - - eval <- summarise_scores(scores_binary, by = c("model", "target_type")) - - expect_equal( - nrow(eval) > 1, - TRUE - ) - expect_equal( - colnames(eval), - c( - "model", "target_type", - "brier_score", - "log_score" - ) - ) - - expect_true("brier_score" %in% names(eval)) - - expect_s3_class(eval, c("scores", "data.table", "data.frame"), exact = TRUE) -}) - -test_that("score.forecast_binary() errors with only NA values", { - # [.forecast()` will warn even before score() - only_nas <- suppressWarnings( - copy(example_binary)[, predicted := NA_real_] - ) - expect_error( - score(only_nas), - "After removing 
rows with NA values in the data, no forecasts are left." - ) -}) - -test_that("score() gives same result for binary as regular function", { - manual_eval <- brier_score( - factor(example_binary$observed), - example_binary$predicted - ) - expect_equal(scores_binary$brier_score, manual_eval[!is.na(manual_eval)]) -}) - -test_that( - "passing additional functions to score binary works handles them", { - test_fun <- function(x, y, ...) { - if (hasArg("test")) { - message("test argument found") - } - return(y) - } - - df <- example_binary[model == "EuroCOVIDhub-ensemble" & - target_type == "Cases" & location == "DE"] %>% - as_forecast_binary() - - # passing a simple function works - expect_equal( - score(df, - metrics = list("identity" = function(x, y) {return(y)}))$identity, - df$predicted - ) - } -) - -# test point case -------------------------------------------------------------- -test_that("function produces output for a point case", { - expect_equal( - names(scores_binary), - c(get_forecast_unit(example_binary), names(get_metrics(example_binary))) - ) - - eval <- summarise_scores(scores_point, by = c("model", "target_type")) - - expect_equal( - nrow(eval) > 1, - TRUE - ) - expect_equal( - colnames(eval), - c("model", "target_type", names(get_metrics(example_point))) - ) - - expect_s3_class(eval, c("scores", "data.table", "data.frame"), exact = TRUE) -}) - -test_that("Changing metrics names works", { - metrics_test <- get_metrics(example_point) - names(metrics_test)[1] = "just_testing" - eval <- suppressMessages(score(as_forecast_point(example_point), - metrics = metrics_test)) - eval_summarised <- summarise_scores(eval, by = "model") - expect_equal( - colnames(eval_summarised), - c("model", "just_testing", names(get_metrics(example_point))[-1]) - ) -}) - - -test_that("score.forecast_point() errors with only NA values", { - # [.forecast()` will warn even before score() - only_nas <- suppressWarnings( - copy(example_point)[, predicted := NA_real_] - ) - expect_error( 
- score(only_nas), - "After removing rows with NA values in the data, no forecasts are left." - ) -}) - -# test quantile case ----------------------------------------------------------- -test_that("score_quantile correctly handles separate results = FALSE", { - df <- example_quantile[model == "EuroCOVIDhub-ensemble" & - target_type == "Cases" & location == "DE"] - metrics <- get_metrics(example_quantile) - metrics$wis <- purrr::partial(wis, separate_results = FALSE) - eval <- score(df[!is.na(predicted)], metrics = metrics) - - expect_equal( - nrow(eval) > 1, - TRUE - ) - expect_true(all(names(get_metrics(example_quantile)) %in% colnames(eval))) - - expect_s3_class(eval, c("scores", "data.table", "data.frame"), exact = TRUE) -}) - - -test_that("score() quantile produces desired metrics", { - data <- data.frame( - observed = rep(1:10, each = 3), - predicted = rep(c(-0.3, 0, 0.3), 10) + rep(1:10, each = 3), - model = "Model 1", - date = as.Date("2020-01-01") + rep(1:10, each = 3), - quantile_level = rep(c(0.1, 0.5, 0.9), times = 10) - ) - - data <-suppressWarnings(suppressMessages(as_forecast_quantile(data))) - - out <- score(forecast = data, metrics = metrics_no_cov) - metrics <- c( - "dispersion", "underprediction", "overprediction", - "bias", "ae_median" - ) - - expect_true(all(metrics %in% colnames(out))) -}) - - -test_that("calculation of ae_median is correct for a quantile format case", { - eval <- summarise_scores(scores_quantile,by = "model") - - example <- as.data.table(example_quantile) - ae <- example[quantile_level == 0.5, ae := abs(observed - predicted)][!is.na(model), .(mean = mean(ae, na.rm = TRUE)), - by = "model" - ]$mean - - expect_equal(sort(eval$ae_median), sort(ae)) -}) - - -test_that("all quantile and range formats yield the same result", { - eval1 <- summarise_scores(scores_quantile, by = "model") - - df <- as.data.table(example_quantile) - - ae <- df[ - quantile_level == 0.5, ae := abs(observed - predicted)][ - !is.na(model), .(mean = mean(ae, 
na.rm = TRUE)), - by = "model" - ]$mean - - expect_equal(sort(eval1$ae_median), sort(ae)) -}) - -test_that("WIS is the same with other metrics omitted or included", { - eval <- score(example_quantile, - metrics = list("wis" = wis) - ) - - eval2 <- scores_quantile - - expect_equal( - sum(eval$wis), - sum(eval2$wis) - ) -}) - - -test_that("score.forecast_quantile() errors with only NA values", { - # [.forecast()` will warn even before score() - only_nas <- suppressWarnings( - copy(example_quantile)[, predicted := NA_real_] - ) - expect_error( - score(only_nas), - "After removing rows with NA values in the data, no forecasts are left." - ) -}) - -test_that("score.forecast_quantile() works as expected in edge cases", { - # only the median - onlymedian <- example_quantile[quantile_level == 0.5] - expect_no_condition( - s <- score(onlymedian, metrics = get_metrics( - example_quantile, - exclude = c("interval_coverage_50", "interval_coverage_90") - )) - ) - expect_equal( - s$wis, abs(onlymedian$observed - onlymedian$predicted) - ) - - # only one symmetric interval is present - oneinterval <- example_quantile[quantile_level %in% c(0.25,0.75)] %>% - as_forecast_quantile() - expect_message( - s <- score( - oneinterval, - metrics = get_metrics( - example_quantile, - exclude = c("interval_coverage_90", "ae_median") - ) - ), - "Median not available" - ) -}) - - - -test_that("score() works even if only some quantiles are missing", { - - # only the median is there - onlymedian <- example_quantile[quantile_level == 0.5] - expect_no_condition( - score(onlymedian, metrics = get_metrics( - example_quantile, - exclude = c("interval_coverage_50", "interval_coverage_90") - )) - ) - - - # asymmetric intervals - asymm <- example_quantile[!quantile_level > 0.6] - expect_warning( - expect_warning( - score_a <- score(asymm) %>% summarise_scores(by = "model"), - "Computation for `interval_coverage_50` failed." - ), - "Computation for `interval_coverage_90` failed." 
- ) - - # check that the result is equal to a case where we discard the entire - # interval in terms of WIS - inner <- example_quantile[quantile_level %in% c(0.4, 0.45, 0.5, 0.55, 0.6)] - score_b <- score(inner, get_metrics( - inner, exclude = c("interval_coverage_50", "interval_coverage_90") - )) %>% - summarise_scores(by = "model") - expect_equal( - score_a$wis, - score_b$wis - ) - - # median is not there, but only in a single model - test <- data.table::copy(example_quantile) - test_no_median <- test[model == "epiforecasts-EpiNow2" & !(quantile_level %in% c(0.5)), ] - test <- rbind(test[model != "epiforecasts-EpiNow2"], test_no_median) - - test <- suppressWarnings(as_forecast_quantile(test)) - expect_message( - expect_warning( - score(test), - "Computation for `ae_median` failed." - ), - "interpolating median from the two innermost quantiles" - ) -}) - # test integer and continuous case --------------------------------------------- test_that("function produces output for a continuous format case", { @@ -444,11 +161,52 @@ test_that("`[` preserves attributes", { # ============================================================================= -# validate_scores() +# assert_scores() # ============================================================================= -test_that("validate_scores() works", { - expect_no_condition(validate_scores(scores_binary)) +test_that("assert_scores() works", { + expect_no_condition(assert_scores(scores_binary)) expect_null( - validate_scores(scores_binary), + assert_scores(scores_binary), + ) +}) + +# ============================================================================== +# validate_metrics() +# ============================================================================== +test_that("validate_metrics() works as expected", { + test_fun <- function(x, y, ...) 
{ + if (hasArg("test")) { + message("test argument found") + } + return(y) + } + ## Additional tests for validate_metrics() + # passing in something that's not a function or a known metric + expect_warning( + expect_warning( + score(as_forecast_binary(na.omit(example_binary)), metrics = list( + "test1" = test_fun, "test" = test_fun, "hi" = "hi", "2" = 3 + )), + "`Metrics` element number 3 is not a valid function" + ), + "`Metrics` element number 4 is not a valid function" ) }) + + +# ============================================================================== +# run_safely() +# ============================================================================== +test_that("run_safely() works as expected", { + f <- function(x) { + x + } + expect_equal(run_safely(2, fun = f), 2) + expect_equal(run_safely(2, y = 3, fun = f), 2) + expect_warning( + run_safely(fun = f, metric_name = "f"), + 'Computation for `f` failed. Error: argument "x" is missing, with no default', + fixed = TRUE + ) + expect_equal(suppressWarnings(run_safely(y = 3, fun = f, metric_name = "f")), NULL) +}) \ No newline at end of file diff --git a/tests/testthat/test-summarise_scores.R b/tests/testthat/test-summarise_scores.R index 420781fea..5186f7490 100644 --- a/tests/testthat/test-summarise_scores.R +++ b/tests/testthat/test-summarise_scores.R @@ -66,7 +66,6 @@ test_that("summarise_scores() handles data.frames correctly", { ) }) - test_that("summarise_scores() errors if `by = NULL", { expect_error( summarise_scores(scores_quantile, by = NULL), diff --git a/tests/testthat/test-convenience-functions.R b/tests/testthat/test-transform-forecasts.R similarity index 54% rename from tests/testthat/test-convenience-functions.R rename to tests/testthat/test-transform-forecasts.R index e8a3d5756..9772b3853 100644 --- a/tests/testthat/test-convenience-functions.R +++ b/tests/testthat/test-transform-forecasts.R @@ -1,7 +1,6 @@ # ============================================================================ # # 
`transform_forecasts()` # ============================================================================ # - test_that("function transform_forecasts works", { predictions_original <- example_quantile$predicted predictions <- example_quantile %>% @@ -65,7 +64,6 @@ test_that("transform_forecasts() outputs an object of class forecast_*", { # ============================================================================ # # `log_shift()` # ============================================================================ # - test_that("log_shift() works as expected", { expect_equal(log_shift(1:10, 1), log(1:10 + 1)) @@ -93,87 +91,3 @@ test_that("log_shift() works as expected", { # test output class is numeric as expected checkmate::expect_class(log_shift(1:10, 1), "numeric") }) - - -# ============================================================================ # -# `set_forecast_unit()` -# ============================================================================ # - -test_that("function set_forecast_unit() works", { - # some columns in the example data have duplicated information. So we can remove - # these and see whether the result stays the same. - scores1 <- scores_quantile[order(location, target_end_date, target_type, horizon, model), ] - - # test that if setting the forecast unit results in an invalid object, - # a warning occurs. 
- expect_warning( - set_forecast_unit(example_quantile, "model"), - "Assertion on 'data' failed: There are instances with more" - ) - - ex2 <- set_forecast_unit( - example_quantile, - c("location", "target_end_date", "target_type", "horizon", "model") - ) - scores2 <- score(na.omit(ex2)) - scores2 <- scores2[order(location, target_end_date, target_type, horizon, model), ] - - expect_equal(scores1$interval_score, scores2$interval_score) -}) - -test_that("set_forecast_unit() works on input that's not a data.table", { - df <- data.frame( - a = 1:2, - b = 2:3, - c = 3:4 - ) - expect_equal( - colnames(set_forecast_unit(df, c("a", "b"))), - c("a", "b") - ) - - expect_equal( - names(set_forecast_unit(as.matrix(df), "a")), - "a" - ) - - expect_s3_class( - set_forecast_unit(df, c("a", "b")), - c("data.table", "data.frame"), - exact = TRUE - ) -}) - -test_that("set_forecast_unit() revalidates a forecast object", { - obj <- as_forecast_quantile(na.omit(example_quantile)) - expect_no_condition( - set_forecast_unit(obj, c("location", "target_end_date", "target_type", "model", "horizon")) - ) -}) - - -test_that("function set_forecast_unit() errors when column is not there", { - expect_error( - set_forecast_unit( - example_quantile, - c("location", "target_end_date", "target_type", "horizon", "model", "test1", "test2") - ), - "Assertion on 'forecast_unit' failed: Must be a subset of " - ) -}) - -test_that("function get_forecast_unit() and set_forecast_unit() work together", { - fu_set <- c("location", "target_end_date", "target_type", "horizon", "model") - ex <- set_forecast_unit(example_binary, fu_set) - fu_get <- get_forecast_unit(ex) - expect_equal(fu_set, fu_get) -}) - -test_that("output class of set_forecast_unit() is as expected", { - ex <- as_forecast_binary(na.omit(example_binary)) - expect_equal( - class(ex), - class(set_forecast_unit(ex, c("location", "target_end_date", "target_type", "horizon", "model"))) - ) -}) -