diff --git a/R/forecast.R b/R/forecast.R index 16e963bb6..95732c676 100644 --- a/R/forecast.R +++ b/R/forecast.R @@ -117,22 +117,10 @@ as_forecast.default <- function(data, data <- set_forecast_unit(data, forecast_unit) } - # find forecast type - desired <- forecast_type + # assert forecast type is as expected + assert_forecast_type(data, desired = forecast_type) forecast_type <- get_forecast_type(data) - if (!is.null(desired) && desired != forecast_type) { - #nolint start: object_usage_linter keyword_quote_linter - cli_abort( - c( - "!" = "Forecast type determined by scoringutils based on input: - {.val {forecast_type}}.", - "i" = "Desired forecast type: {.val {desired}}." - ) - ) - #nolint end - } - # produce warning if old format is suspected # old quantile format if (forecast_type == "point" && "quantile" %in% colnames(data)) { @@ -194,7 +182,7 @@ as_forecast.default <- function(data, #' @examples #' forecast <- as_forecast(example_binary) #' validate_forecast(forecast) -validate_forecast <- function(data, ...) { +validate_forecast <- function(data, forecast_type = NULL, ...) { UseMethod("validate_forecast") } @@ -202,7 +190,7 @@ validate_forecast <- function(data, ...) { #' @importFrom cli cli_abort #' @export #' @keywords check-forecasts -validate_forecast.default <- function(data, ...) { +validate_forecast.default <- function(data, forecast_type = NULL, ...) { cli_abort( c( "!" = "The input needs to be a forecast object.", @@ -215,8 +203,9 @@ validate_forecast.default <- function(data, ...) { #' @export #' @importFrom cli cli_abort #' @keywords check-forecasts -validate_forecast.forecast_binary <- function(data, ...) { +validate_forecast.forecast_binary <- function(data, forecast_type = NULL, ...) { data <- validate_general(data) + assert_forecast_type(data, actual = "binary", desired = forecast_type) columns_correct <- test_columns_not_present( data, c("sample_id", "quantile_level") @@ -248,8 +237,9 @@ validate_forecast.forecast_binary <- function(data, ...) { #' @export #' @importFrom cli cli_abort #' @keywords check-forecasts -validate_forecast.forecast_point <- function(data, ...) { +validate_forecast.forecast_point <- function(data, forecast_type = NULL, ...) { data <- validate_general(data) + assert_forecast_type(data, actual = "point", desired = forecast_type) #nolint start: keyword_quote_linter object_usage_linter input_check <- check_input_point(data$observed, data$predicted) if (!is.logical(input_check)) { @@ -268,8 +258,10 @@ validate_forecast.forecast_point <- function(data, ...) { #' @export #' @rdname validate_forecast #' @keywords check-forecasts -validate_forecast.forecast_quantile <- function(data, ...) { +validate_forecast.forecast_quantile <- function(data, + forecast_type = NULL, ...) { data <- validate_general(data) + assert_forecast_type(data, actual = "quantile", desired = forecast_type) assert_numeric(data$quantile_level, lower = 0, upper = 1) return(data[]) } @@ -278,8 +270,9 @@ validate_forecast.forecast_quantile <- function(data, ...) { #' @export #' @rdname validate_forecast #' @keywords check-forecasts -validate_forecast.forecast_sample <- function(data, ...) { +validate_forecast.forecast_sample <- function(data, forecast_type = NULL, ...) { data <- validate_general(data) + assert_forecast_type(data, actual = "sample", desired = forecast_type) return(data[]) } diff --git a/R/get_-functions.R b/R/get_-functions.R index adc188c70..fa5af1169 100644 --- a/R/get_-functions.R +++ b/R/get_-functions.R @@ -98,6 +98,33 @@ test_forecast_type_is_quantile <- function(data) { } +#' Assert that forecast type is as expected +#' @param data A forecast object as produced by [as_forecast()]. +#' @param actual The actual forecast type of the data +#' @param desired The desired forecast type of the data +#' @inherit document_assert_functions return +#' @importFrom cli cli_abort +#' @importFrom checkmate assert_character +#' @keywords internal_input_check +assert_forecast_type <- function(data, + actual = get_forecast_type(data), + desired = NULL) { + assert_character(desired, null.ok = TRUE) + if (!is.null(desired) && desired != actual) { + #nolint start: object_usage_linter keyword_quote_linter + cli_abort( + c( + "!" = "Forecast type determined by scoringutils based on input: + {.val {actual}}.", + "i" = "Desired forecast type: {.val {desired}}." + ) + ) + #nolint end + } + return(invisible(NULL)) +} + + #' @title Get type of a vector or matrix of observed values or predictions #' #' @description Internal helper function to get the type of a vector (usually diff --git a/man/assert_forecast_type.Rd b/man/assert_forecast_type.Rd new file mode 100644 index 000000000..d060ac8be --- /dev/null +++ b/man/assert_forecast_type.Rd @@ -0,0 +1,23 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/get_-functions.R +\name{assert_forecast_type} +\alias{assert_forecast_type} +\title{Assert that forecast type is as expected} +\usage{ +assert_forecast_type(data, actual = get_forecast_type(data), desired = NULL) +} +\arguments{ +\item{data}{A forecast object as produced by \code{\link[=as_forecast]{as_forecast()}}.} + +\item{actual}{The actual forecast type of the data} + +\item{desired}{The desired forecast type of the data} +} +\value{ +Returns NULL invisibly if the assertion was successful and throws an +error otherwise. +} +\description{ +Assert that forecast type is as expected +} +\keyword{internal_input_check} diff --git a/man/validate_forecast.Rd b/man/validate_forecast.Rd index 6b4f7ed2a..ca98820c7 100644 --- a/man/validate_forecast.Rd +++ b/man/validate_forecast.Rd @@ -6,16 +6,21 @@ \alias{validate_forecast.forecast_sample} \title{Validate input data} \usage{ -validate_forecast(data, ...) +validate_forecast(data, forecast_type = NULL, ...) -\method{validate_forecast}{forecast_quantile}(data, ...) +\method{validate_forecast}{forecast_quantile}(data, forecast_type = NULL, ...) -\method{validate_forecast}{forecast_sample}(data, ...) +\method{validate_forecast}{forecast_sample}(data, forecast_type = NULL, ...) } \arguments{ \item{data}{A data.frame (or similar) with predicted and observed values. See \code{\link[=as_forecast]{as_forecast()}} for additional information on input formats.} +\item{forecast_type}{(optional) The forecast type you expect the forecasts +to have. If the forecast type as determined by \code{scoringutils} based on the +input does not match this, an error will be thrown. If \code{NULL} (the default), +the forecast type will be inferred from the data.} + \item{...}{additional arguments} } \value{ diff --git a/tests/testthat/test-forecast.R b/tests/testthat/test-forecast.R index 55c7be9ae..22aaf418f 100644 --- a/tests/testthat/test-forecast.R +++ b/tests/testthat/test-forecast.R @@ -264,3 +264,13 @@ test_that("validate_forecast.forecast_point() works as expected", { "Input looks like a point forecast, but found the following issue" ) }) + +test_that("validate_forecast() complains if the forecast type is wrong", { + test <- na.omit(data.table::copy(example_point)) + test <- as_forecast(test) + expect_error( + validate_forecast(test, forecast_type = "quantile"), + "Forecast type determined by scoringutils based on input:" + ) +}) +