diff --git a/R/get-forecast-type.R b/R/get-forecast-type.R index 3f6fe1ab..5d2eb5d1 100644 --- a/R/get-forecast-type.R +++ b/R/get-forecast-type.R @@ -4,14 +4,14 @@ #' Character vector of length one with the forecast type. #' @keywords internal_input_check get_forecast_type <- function(forecast) { - classname <- class(forecast)[1] - if (grepl("forecast_", classname, fixed = TRUE)) { - type <- gsub("forecast_", "", classname, fixed = TRUE) - return(type) + classname <- class(forecast) + forecast_class <- classname[grepl("forecast_", classname, fixed = TRUE)] + if (length(forecast_class) == 1) { + return(gsub("forecast_", "", forecast_class, fixed = TRUE)) } else { cli_abort( "Input is not a valid forecast object - (it's first class should begin with `forecast_`)." + (There should be a single class beginning with `forecast_`)." ) } } diff --git a/tests/testthat/test-get-forecast-type.R b/tests/testthat/test-get-forecast-type.R index 51deb134..0368f280 100644 --- a/tests/testthat/test-get-forecast-type.R +++ b/tests/testthat/test-get-forecast-type.R @@ -21,6 +21,11 @@ test_that("get_forecast_type() works as expected", { get_forecast_type(test), "Input is not a valid forecast object", ) + + # get_forecast_type() should still work even if a new class is added + testclassobject <- data.table::copy(example_quantile) + class(testclassobject) <- c("something", class(testclassobject)) + expect_equal(get_forecast_type(testclassobject), "quantile") })