From 131fbf401906ef733503eef3bdf12c2466cc4f97 Mon Sep 17 00:00:00 2001 From: nikosbosse Date: Tue, 26 Mar 2024 18:42:15 +0100 Subject: [PATCH 1/6] Create function to assert forecast type is as expected --- R/get_-functions.R | 25 +++++++++++++++++++++++++ man/assert_forecast_type.Rd | 23 +++++++++++++++++++++++ 2 files changed, 48 insertions(+) create mode 100644 man/assert_forecast_type.Rd diff --git a/R/get_-functions.R b/R/get_-functions.R index adc188c70..af4eff847 100644 --- a/R/get_-functions.R +++ b/R/get_-functions.R @@ -98,6 +98,31 @@ test_forecast_type_is_quantile <- function(data) { } +#' Assert that forecast type is as expected +#' @param data A forecast object as produced by [as_forecast()]. +#' @inheritParams as_forecast +#' @inherit document_assert_functions return +#' @importFrom cli cli_abort +#' @importFrom checkmate assert_character +assert_forecast_type <- function(data, forecast_type = NULL) { + assert_character(forecast_type, null.ok = TRUE) + desired <- forecast_type + forecast_type <- get_forecast_type(data) + if (!is.null(desired) && desired != forecast_type) { + #nolint start: object_usage_linter keyword_quote_linter + cli_abort( + c( + "!" = "Forecast type determined by scoringutils based on input: + {.val {forecast_type}}.", + "i" = "Desired forecast type: {.val {desired}}." + ) + ) + #nolint end + } + return(invisible(NULL)) +} + + #' @title Get type of a vector or matrix of observed values or predictions #' #' @description Internal helper function to get the type of a vector (usually diff --git a/man/assert_forecast_type.Rd b/man/assert_forecast_type.Rd new file mode 100644 index 000000000..b5a99bda9 --- /dev/null +++ b/man/assert_forecast_type.Rd @@ -0,0 +1,23 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/get_-functions.R +\name{assert_forecast_type} +\alias{assert_forecast_type} +\title{Assert that forecast type is as expected} +\usage{ +assert_forecast_type(data, forecast_type = NULL) +} +\arguments{ +\item{data}{A forecast object as produced by \code{\link[=as_forecast]{as_forecast()}}.} + +\item{forecast_type}{(optional) The forecast type you expect the forecasts +to have. If the forecast type as determined by \code{scoringutils} based on the +input does not match this, an error will be thrown. If \code{NULL} (the default), +the forecast type will be inferred from the data.} +} +\value{ +Returns NULL invisibly if the assertion was successful and throws an +error otherwise. +} +\description{ +Assert that forecast type is as expected +} From 470579326538e9552e13ba27c6657f64b88e42af Mon Sep 17 00:00:00 2001 From: nikosbosse Date: Tue, 26 Mar 2024 18:42:41 +0100 Subject: [PATCH 2/6] Add additional arg to `validate_forecast()` to check for forecast type --- R/forecast.R | 35 ++++++++++++++--------------------- man/validate_forecast.Rd | 11 ++++++++--- 2 files changed, 22 insertions(+), 24 deletions(-) diff --git a/R/forecast.R b/R/forecast.R index 16e963bb6..c4085a1c8 100644 --- a/R/forecast.R +++ b/R/forecast.R @@ -117,21 +117,8 @@ as_forecast.default <- function(data, data <- set_forecast_unit(data, forecast_unit) } - # find forecast type - desired <- forecast_type - forecast_type <- get_forecast_type(data) - - if (!is.null(desired) && desired != forecast_type) { - #nolint start: object_usage_linter keyword_quote_linter - cli_abort( - c( - "!" = "Forecast type determined by scoringutils based on input: - {.val {forecast_type}}.", - "i" = "Desired forecast type: {.val {desired}}." - ) - ) - #nolint end - } + # assert forecast type is as expected + assert_forecast_type(data, forecast_type) # produce warning if old format is suspected # old quantile format @@ -194,7 +181,7 @@ as_forecast.default <- function(data, #' @examples #' forecast <- as_forecast(example_binary) #' validate_forecast(forecast) -validate_forecast <- function(data, ...) { +validate_forecast <- function(data, forecast_type = NULL, ...) { UseMethod("validate_forecast") } @@ -202,7 +189,7 @@ validate_forecast <- function(data, ...) { #' @importFrom cli cli_abort #' @export #' @keywords check-forecasts -validate_forecast.default <- function(data, ...) { +validate_forecast.default <- function(data, forecast_type = NULL, ...) { cli_abort( c( "!" = "The input needs to be a forecast object.", @@ -215,8 +202,9 @@ validate_forecast.default <- function(data, ...) { #' @export #' @importFrom cli cli_abort #' @keywords check-forecasts -validate_forecast.forecast_binary <- function(data, ...) { +validate_forecast.forecast_binary <- function(data, forecast_type = NULL, ...) { data <- validate_general(data) + assert_forecast_type(data, forecast_type) columns_correct <- test_columns_not_present( data, c("sample_id", "quantile_level") @@ -248,8 +236,9 @@ validate_forecast.forecast_binary <- function(data, ...) { #' @export #' @importFrom cli cli_abort #' @keywords check-forecasts -validate_forecast.forecast_point <- function(data, ...) { +validate_forecast.forecast_point <- function(data, forecast_type = NULL, ...) { data <- validate_general(data) + assert_forecast_type(data, forecast_type) #nolint start: keyword_quote_linter object_usage_linter input_check <- check_input_point(data$observed, data$predicted) if (!is.logical(input_check)) { @@ -268,8 +257,10 @@ validate_forecast.forecast_point <- function(data, ...) { #' @export #' @rdname validate_forecast #' @keywords check-forecasts -validate_forecast.forecast_quantile <- function(data, ...) { +validate_forecast.forecast_quantile <- function(data, + forecast_type = NULL, ...) { data <- validate_general(data) + assert_forecast_type(data, forecast_type) assert_numeric(data$quantile_level, lower = 0, upper = 1) return(data[]) } @@ -278,8 +269,10 @@ validate_forecast.forecast_quantile <- function(data, ...) { #' @export #' @rdname validate_forecast #' @keywords check-forecasts -validate_forecast.forecast_sample <- function(data, ...) { +validate_forecast.forecast_sample <- function(data, forecast_type = NULL, ...) { + data <- validate_general(data) + assert_forecast_type(data, forecast_type) return(data[]) } diff --git a/man/validate_forecast.Rd b/man/validate_forecast.Rd index 6b4f7ed2a..ca98820c7 100644 --- a/man/validate_forecast.Rd +++ b/man/validate_forecast.Rd @@ -6,16 +6,21 @@ \alias{validate_forecast.forecast_sample} \title{Validate input data} \usage{ -validate_forecast(data, ...) +validate_forecast(data, forecast_type = NULL, ...) -\method{validate_forecast}{forecast_quantile}(data, ...) +\method{validate_forecast}{forecast_quantile}(data, forecast_type = NULL, ...) -\method{validate_forecast}{forecast_sample}(data, ...) +\method{validate_forecast}{forecast_sample}(data, forecast_type = NULL, ...) } \arguments{ \item{data}{A data.frame (or similar) with predicted and observed values. See \code{\link[=as_forecast]{as_forecast()}} for additional information on input formats.} +\item{forecast_type}{(optional) The forecast type you expect the forecasts +to have. If the forecast type as determined by \code{scoringutils} based on the +input does not match this, an error will be thrown. If \code{NULL} (the default), +the forecast type will be inferred from the data.} + \item{...}{additional arguments} } \value{ From 44ddb3ca83c66b1a73cddd413425f69601e2042d Mon Sep 17 00:00:00 2001 From: nikosbosse Date: Tue, 26 Mar 2024 18:48:03 +0100 Subject: [PATCH 3/6] small correction --- R/forecast.R | 1 + 1 file changed, 1 insertion(+) diff --git a/R/forecast.R b/R/forecast.R index c4085a1c8..08982939a 100644 --- a/R/forecast.R +++ b/R/forecast.R @@ -119,6 +119,7 @@ as_forecast.default <- function(data, # assert forecast type is as expected assert_forecast_type(data, forecast_type) + forecast_type <- get_forecast_type(data) # produce warning if old format is suspected # old quantile format From bea81b8778c1c6f00285850022a4301072442440 Mon Sep 17 00:00:00 2001 From: nikosbosse Date: Tue, 26 Mar 2024 18:48:12 +0100 Subject: [PATCH 4/6] create test --- tests/testthat/test-forecast.R | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/tests/testthat/test-forecast.R b/tests/testthat/test-forecast.R index 55c7be9ae..22aaf418f 100644 --- a/tests/testthat/test-forecast.R +++ b/tests/testthat/test-forecast.R @@ -264,3 +264,13 @@ test_that("validate_forecast.forecast_point() works as expected", { "Input looks like a point forecast, but found the following issue" ) }) + +test_that("validate_forecast() complains if the forecast type is wrong", { + test <- na.omit(data.table::copy(example_point)) + test <- as_forecast(test) + expect_error( + validate_forecast(test, forecast_type = "quantile"), + "Forecast type determined by scoringutils based on input:" + ) +}) + From e545da31f883e770523b916f21e9e1ea038e7b2a Mon Sep 17 00:00:00 2001 From: nikosbosse Date: Tue, 26 Mar 2024 19:03:11 +0100 Subject: [PATCH 5/6] add pkgdown keyword to function --- man/assert_forecast_type.Rd | 1 + 1 file changed, 1 insertion(+) diff --git a/man/assert_forecast_type.Rd b/man/assert_forecast_type.Rd index b5a99bda9..2026d6956 100644 --- a/man/assert_forecast_type.Rd +++ b/man/assert_forecast_type.Rd @@ -21,3 +21,4 @@ error otherwise. \description{ Assert that forecast type is as expected } +\keyword{internal_input_check} From 232a83226060605b3a25c3084f5995e84f4cdf38 Mon Sep 17 00:00:00 2001 From: nikosbosse Date: Tue, 26 Mar 2024 19:03:46 +0100 Subject: [PATCH 6/6] update default args to `assert_forecast_type()` --- R/forecast.R | 11 +++++------ R/get_-functions.R | 16 +++++++++------- man/assert_forecast_type.Rd | 9 ++++----- 3 files changed, 18 insertions(+), 18 deletions(-) diff --git a/R/forecast.R b/R/forecast.R index 08982939a..95732c676 100644 --- a/R/forecast.R +++ b/R/forecast.R @@ -118,7 +118,7 @@ as_forecast.default <- function(data, } # assert forecast type is as expected - assert_forecast_type(data, forecast_type) + assert_forecast_type(data, desired = forecast_type) forecast_type <- get_forecast_type(data) # produce warning if old format is suspected @@ -205,7 +205,7 @@ validate_forecast.default <- function(data, forecast_type = NULL, ...) { #' @keywords check-forecasts validate_forecast.forecast_binary <- function(data, forecast_type = NULL, ...) { data <- validate_general(data) - assert_forecast_type(data, forecast_type) + assert_forecast_type(data, actual = "binary", desired = forecast_type) columns_correct <- test_columns_not_present( data, c("sample_id", "quantile_level") @@ -239,7 +239,7 @@ validate_forecast.forecast_binary <- function(data, forecast_type = NULL, ...) { #' @keywords check-forecasts validate_forecast.forecast_point <- function(data, forecast_type = NULL, ...) { data <- validate_general(data) - assert_forecast_type(data, forecast_type) + assert_forecast_type(data, actual = "point", desired = forecast_type) #nolint start: keyword_quote_linter object_usage_linter input_check <- check_input_point(data$observed, data$predicted) if (!is.logical(input_check)) { @@ -261,7 +261,7 @@ validate_forecast.forecast_point <- function(data, forecast_type = NULL, ...) { validate_forecast.forecast_quantile <- function(data, forecast_type = NULL, ...) { data <- validate_general(data) - assert_forecast_type(data, forecast_type) + assert_forecast_type(data, actual = "quantile", desired = forecast_type) assert_numeric(data$quantile_level, lower = 0, upper = 1) return(data[]) } @@ -271,9 +271,8 @@ validate_forecast.forecast_quantile <- function(data, #' @rdname validate_forecast #' @keywords check-forecasts validate_forecast.forecast_sample <- function(data, forecast_type = NULL, ...) { - data <- validate_general(data) - assert_forecast_type(data, forecast_type) + assert_forecast_type(data, actual = "sample", desired = forecast_type) return(data[]) } diff --git a/R/get_-functions.R b/R/get_-functions.R index af4eff847..fa5af1169 100644 --- a/R/get_-functions.R +++ b/R/get_-functions.R @@ -100,20 +100,22 @@ test_forecast_type_is_quantile <- function(data) { #' Assert that forecast type is as expected #' @param data A forecast object as produced by [as_forecast()]. -#' @inheritParams as_forecast +#' @param actual The actual forecast type of the data +#' @param desired The desired forecast type of the data #' @inherit document_assert_functions return #' @importFrom cli cli_abort #' @importFrom checkmate assert_character -assert_forecast_type <- function(data, forecast_type = NULL) { - assert_character(forecast_type, null.ok = TRUE) - desired <- forecast_type - forecast_type <- get_forecast_type(data) - if (!is.null(desired) && desired != forecast_type) { +#' @keywords internal_input_check +assert_forecast_type <- function(data, + actual = get_forecast_type(data), + desired = NULL) { + assert_character(desired, null.ok = TRUE) + if (!is.null(desired) && desired != actual) { #nolint start: object_usage_linter keyword_quote_linter cli_abort( c( "!" = "Forecast type determined by scoringutils based on input: - {.val {forecast_type}}.", + {.val {actual}}.", "i" = "Desired forecast type: {.val {desired}}." ) ) diff --git a/man/assert_forecast_type.Rd b/man/assert_forecast_type.Rd index 2026d6956..d060ac8be 100644 --- a/man/assert_forecast_type.Rd +++ b/man/assert_forecast_type.Rd @@ -4,15 +4,14 @@ \alias{assert_forecast_type} \title{Assert that forecast type is as expected} \usage{ -assert_forecast_type(data, forecast_type = NULL) +assert_forecast_type(data, actual = get_forecast_type(data), desired = NULL) } \arguments{ \item{data}{A forecast object as produced by \code{\link[=as_forecast]{as_forecast()}}.} -\item{forecast_type}{(optional) The forecast type you expect the forecasts -to have. If the forecast type as determined by \code{scoringutils} based on the -input does not match this, an error will be thrown. If \code{NULL} (the default), -the forecast type will be inferred from the data.} +\item{actual}{The actual forecast type of the data} + +\item{desired}{The desired forecast type of the data} } \value{ Returns NULL invisibly if the assertion was successful and throws an