Skip to content

Commit

Permalink
Merge pull request #157 from epiforecasts/master
Browse files Browse the repository at this point in the history
update branch from master
  • Loading branch information
nikosbosse authored Nov 24, 2021
2 parents 9160039 + ba120ae commit 42f4d8b
Show file tree
Hide file tree
Showing 19 changed files with 3,601 additions and 51 deletions.
8 changes: 7 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ Authors@R: c(
role = c("aut"),
email = "[email protected]",
comment = c(ORCID = "0000-0001-8057-8037")),
person(given = "Hugo",
family = "Gruson",
role = c("aut"),
email = "[email protected]",
comment = c(ORCID = "https://orcid.org/0000-0002-4094-1476")),
person(given = "Johannes Bracher",
role = c("ctb"),
email = "[email protected]",
Expand Down Expand Up @@ -65,7 +70,8 @@ Imports:
Suggests:
testthat,
knitr,
rmarkdown
rmarkdown,
vdiffr
RoxygenNote: 7.1.1
URL: https://github.com/epiforecasts/scoringutils, https://epiforecasts.io/scoringutils/
BugReports: https://github.com/epiforecasts/scoringutils/issues
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ S3method(print,scoringutils_check)
export(abs_error)
export(ae_median_quantile)
export(ae_median_sample)
export(available_metrics)
export(bias)
export(brier_score)
export(check_forecasts)
Expand Down
2 changes: 1 addition & 1 deletion R/bias.R
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ bias <- function(true_values, predictions) {
# ============================================

## check whether continuous or integer
if (all.equal(as.vector(predictions), as.integer(predictions)) != TRUE) {
if (!isTRUE(all.equal(as.vector(predictions), as.integer(predictions)))) {
continuous_predictions <- TRUE
} else {
continuous_predictions <- FALSE
Expand Down
9 changes: 5 additions & 4 deletions R/eval_forecasts.R
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,8 @@
#' may want to include 'range', 'quantile' or 'sample', to summarise by
#' range, quantile or sample.
#' @param metrics the metrics you want to have in the output. If `NULL` (the
#' default), all available metrics will be computed.
#' default), all available metrics will be computed. For a list of available
#' metrics, see [available_metrics()].
#' @param quantiles numeric vector of quantiles to be returned when summarising.
#' Instead of just returning a mean, quantiles will be returned for the
#' groups specified through `summarise_by`. By default, no quantiles are
Expand Down Expand Up @@ -242,7 +243,7 @@ eval_forecasts <- function(data = NULL,
}

# check metrics to be computed
available_metrics <- list_of_avail_metrics()
available_metrics <- available_metrics()
if (is.null(metrics)) {
metrics <- available_metrics
} else {
Expand All @@ -259,13 +260,13 @@ eval_forecasts <- function(data = NULL,
if (any(grepl("lower", names(data))) | "boundary" %in% names(data) |
"quantile" %in% names(data) | "range" %in% names(data)) {
prediction_type <- "quantile"
} else if (all.equal(data$prediction, as.integer(data$prediction)) == TRUE) {
} else if (isTRUE(all.equal(data$prediction, as.integer(data$prediction)))) {
prediction_type <- "integer"
} else {
prediction_type <- "continuous"
}

if (all.equal(data$true_value, as.integer(data$true_value)) == TRUE) {
if (isTRUE(all.equal(data$true_value, as.integer(data$true_value)))) {
if (all(data$true_value %in% c(0,1)) && all(data$prediction >= 0) && all(data$prediction <= 1)) {
target_type = "binary"
} else {
Expand Down
2 changes: 1 addition & 1 deletion R/eval_forecasts_continuous_integer.R
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ eval_forecasts_sample <- function(data,
pit_plots) {

if (missing(prediction_type)) {
if (all.equal(data$prediction, as.integer(data$prediction)) == TRUE) {
if (isTRUE(all.equal(data$prediction, as.integer(data$prediction)))) {
prediction_type <- "integer"
} else {
prediction_type <- "continuous"
Expand Down
2 changes: 1 addition & 1 deletion R/pairwise-comparisons.R
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ pairwise_comparison <- function(scores,
# usually, by = NULL should be fine and only needs to be specified if there
# are additional columns that are not metrics and not related to the unit of observation
if (is.null(by)) {
all_metrics <- list_of_avail_metrics()
all_metrics <- available_metrics()
by <- setdiff(names(scores), c(all_metrics, "model"))
}

Expand Down
6 changes: 3 additions & 3 deletions R/pit.R
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@
#' number of Monte Carlo samples
#' @param plot logical. If `TRUE`, a histogram of the PIT values will be
#' returned as well
#' @param num_bins the number of bins in the PIT histogram (if `plot == TRUE`)
#' @param num_bins the number of bins in the PIT histogram (if `plot = TRUE`).
#' If not given, the square root of n will be used
#' @param n_replicates the number of tests to perform,
#' each time re-randomising the PIT
Expand Down Expand Up @@ -148,7 +148,7 @@ pit <- function(true_values,

# check data type ------------------------------------------------------------
# check whether continuous or integer
if (all.equal(as.vector(predictions), as.integer(predictions)) != TRUE) {
if (!isTRUE(all.equal(as.vector(predictions), as.integer(predictions)))) {
continuous_predictions <- TRUE
} else {
continuous_predictions <- FALSE
Expand Down Expand Up @@ -230,7 +230,7 @@ pit <- function(true_values,
#' \item `data`: the input data.frame (not including rows where prediction is `NA`),
#' with added columns `pit_p_val` and `pit_sd`
#' \item `hist_PIT` a plot object with the PIT histogram. Only returned
#' if `plot == TRUE`. Call
#' if `plot = TRUE`. Call
#' `plot(PIT(...)$hist_PIT)` to display the histogram.
#' \item `p_values`: all p_values generated from the Anderson-Darling tests on the
#' (randomised) PIT. Only returned if `full_output = TRUE`
Expand Down
68 changes: 35 additions & 33 deletions R/plot.R
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ score_table <- function(summarised_scores,
# identify metrics -----------------------------------------------------------
# identify metrics by looking at which of the available column names
# are metrics. All other variables are treated as identifier variables
all_metrics <- list_of_avail_metrics()
all_metrics <- available_metrics()

metrics <- names(summarised_scores)[names(summarised_scores) %in% all_metrics]
id_vars <- names(summarised_scores)[!(names(summarised_scores) %in% all_metrics)]
Expand Down Expand Up @@ -196,7 +196,7 @@ correlation_plot <- function(scores,
select_metrics = NULL) {

# define possible metrics
all_metrics <- list_of_avail_metrics()
all_metrics <- available_metrics()

# find metrics present
metrics <- names(scores)[names(scores) %in% all_metrics]
Expand Down Expand Up @@ -695,6 +695,7 @@ plot_predictions <- function(data = NULL,
colnames <- colnames(forecasts)
if ("sample" %in% colnames) {
forecasts <- scoringutils::sample_to_range_long(forecasts,
range = range,
keep_quantile_col = FALSE)
} else if ("quantile" %in% colnames) {
forecasts <- scoringutils::quantile_to_range_long(forecasts,
Expand All @@ -711,28 +712,28 @@ plot_predictions <- function(data = NULL,
intervals[, quantile := NULL]
}

# if there isn't any data to plot, return NULL
if (nrow(intervals) == 0) {
return(NULL)
}

# pivot wider and convert range to a factor
intervals <- data.table::dcast(intervals, ... ~ boundary,
value.var = "prediction")
intervals[, range := as.factor(range)]
pal <- grDevices::colorRampPalette(c("lightskyblue1", "steelblue3"))

# plot prediction ranges
plot <- ggplot2::ggplot(intervals, ggplot2::aes(x = !!ggplot2::sym(x))) +
ggplot2::geom_ribbon(ggplot2::aes(ymin = lower, ymax = upper,
group = range, fill = range),
alpha = 0.4) +
plot <- ggplot2::ggplot(data = data, aes(x = !!ggplot2::sym(x))) +
ggplot2::scale_colour_manual("",values = c("black", "steelblue4")) +
ggplot2::scale_fill_manual("range", values = c("steelblue3",
"lightskyblue3",
"lightskyblue2",
"lightskyblue1")) +
ggplot2::scale_fill_manual(name = "range", values = pal(length(range))) +
ggplot2::theme_light()

if (nrow(intervals) != 0) {
# pivot wider and convert range to a factor
intervals <- data.table::dcast(intervals, ... ~ boundary,
value.var = "prediction")
intervals[, range := factor(range,
levels = sort(unique(range), decreasing = TRUE),
ordered = TRUE)]

# plot prediction ranges
plot <- plot +
ggplot2::geom_ribbon(data = intervals,
ggplot2::aes(ymin = lower, ymax = upper,
group = range, fill = range))
}

# add median in a different colour
if (0 %in% range) {
select_median <- (forecasts$range %in% 0 & forecasts$boundary == "lower")
Expand All @@ -746,6 +747,20 @@ plot_predictions <- function(data = NULL,
}
}

# add true_values
if (nrow(truth_data) > 0) {
plot <- plot +
ggplot2::geom_point(data = truth_data,
ggplot2::aes(y = true_value, colour = "actual"),
size = 0.5) +
ggplot2::geom_line(data = truth_data,
ggplot2::aes(y = true_value, colour = "actual"),
lwd = 0.2)
}

plot <- plot +
ggplot2::labs(x = xlab, y = ylab)

# facet if specified by the user
if (!is.null(facet_formula)) {
if (facet_wrap_or_grid == "facet_wrap") {
Expand All @@ -757,19 +772,6 @@ plot_predictions <- function(data = NULL,
}
}

# add true_values
if (nrow(truth_data) > 0) {
plot <- plot +
ggplot2::labs(x = xlab, y = ylab)

plot <- plot +
ggplot2::geom_point(data = truth_data,
ggplot2::aes(y = true_value, colour = "actual"),
size = 0.5) +
ggplot2::geom_line(data = truth_data,
ggplot2::aes(y = true_value, colour = "actual"),
lwd = 0.2)
}
return(plot)
}

Expand Down
7 changes: 6 additions & 1 deletion R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,12 @@ globalVariables(c("..index",
"g"))


list_of_avail_metrics <- function() {
#' @title Available metrics in scoringutils
#'
#' @return A vector with the names of all available metrics.
#' @export

available_metrics <- function() {
available_metrics <- c("ae_point", "aem", "log_score", "sharpness", "bias", "dss", "crps",
"coverage", "coverage_deviation", "quantile_coverage",
"pit_p_val", "pit_sd","interval_score",
Expand Down
14 changes: 14 additions & 0 deletions man/available_metrics.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion man/eval_forecasts.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion man/eval_forecasts_binary.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion man/eval_forecasts_sample.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion man/pit.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

4 changes: 2 additions & 2 deletions man/pit_df.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 42f4d8b

Please sign in to comment.