Skip to content

Commit

Permalink
fix plotting with updated scoringutils (#89)
Browse files Browse the repository at this point in the history
* Xadd linting

* update Rbuildignore

* fix plotting with updated scoringutils

* update docs

* add a simple test for plot_compare_timeseries + fix use of .dots in group_by

* deal with size depreciation in line geoms

* fix final depreciation warning

* fix linting issues

* add pick to namespace

---------

Co-authored-by: Sam <[email protected]>
  • Loading branch information
sbfnk and seabbs authored Mar 1, 2023
1 parent 2b010e0 commit 40731b4
Show file tree
Hide file tree
Showing 8 changed files with 71 additions and 41 deletions.
2 changes: 1 addition & 1 deletion .Rbuildignore
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,4 @@
^_pkgdown\.yml$
^pkgdown$
^.lintr$
^codecov\.yml$
^codecov\.yml$
1 change: 0 additions & 1 deletion .lintr
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,3 @@ linters: linters_with_tags(
todo_comment_linter = NULL,
function_argument_linter = NULL
)

1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ importFrom(dplyr,group_by_at)
importFrom(dplyr,group_split)
importFrom(dplyr,mutate)
importFrom(dplyr,n)
importFrom(dplyr,pick)
importFrom(dplyr,recode_factor)
importFrom(dplyr,rename)
importFrom(dplyr,sample_frac)
Expand Down
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
* Updated all GitHub Actions
* Updated `tidyverse` code usage to account for depreciation.
* Updated `pkgdown` site to use `pkgdown` 2.0.0.
* Updated all uses of `size` to `linewidth` to account for depreciation in `ggplot2` line geoms.

# EpiSoon 0.3.0

Expand Down
61 changes: 25 additions & 36 deletions R/plot.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
#' following variables:
#' - either `rt` or `cases`
#' - and `date`.
#' @param horizon_cutoff Numeric, defaults to NULL. Forecast horizon to plot up to.
#' @param obs_cutoff_at_forecast Logical defaults to `TRUE`. Should observations only be shown
#' @param horizon_cutoff Numeric, defaults to NULL. Forecast horizon to plot up
#' to.
#' @param obs_cutoff_at_forecast Logical defaults to `TRUE`. Should
#' observations only be shown
#' up to the date of the forecast.
#' @importFrom dplyr filter
#' @importFrom ggplot2 ggplot aes geom_line geom_ribbon scale_x_date labs
Expand Down Expand Up @@ -74,7 +76,9 @@ plot_forecast <- function(forecast = NULL,
ggplot2::geom_line(ggplot2::aes(y = bottom), alpha = 0.5) +
ggplot2::geom_line(ggplot2::aes(y = top), alpha = 0.5) +
ggplot2::geom_ribbon(ggplot2::aes(ymin = bottom, ymax = top), alpha = 0.1) +
ggplot2::geom_ribbon(ggplot2::aes(ymin = lower, ymax = upper), alpha = 0.2) +
ggplot2::geom_ribbon(
ggplot2::aes(ymin = lower, ymax = upper), alpha = 0.2
) +
ggplot2::geom_point(
data = observations,
ggplot2::aes(y = y), size = 1.1,
Expand Down Expand Up @@ -161,7 +165,8 @@ plot_forecast_evaluation <- function(forecasts = NULL,
#' @return A dataframe of summarised scores in a tidy format.
#' @export
plot_scores <- function() {
## Some thought required here as to what the best - most general purpose scoring plot would be.
## Some thought required here as to what the best - most general purpose
## scoring plot would be.
}


Expand All @@ -186,6 +191,7 @@ plot_scores <- function() {
#' @return A named list of `ggplot2` objects
#' @export
#' @importFrom dplyr mutate bind_rows filter group_by ungroup recode_factor
#' pick
#' @importFrom cowplot plot_grid theme_cowplot panel_border
#' @importFrom purrr map_dfr
#' @importFrom lubridate days
Expand Down Expand Up @@ -249,11 +255,9 @@ plot_compare_timeseries <- function(compare_timeseries_output,
score = c(
"Bias",
"CRPS",
"Sharpness",
"Calibration",
"Median",
"IQR",
"CI"
"Dispersion",
"AE (median)",
"SE (mean)"
)) {
## Prepare plotting output
plot_output <- list()
Expand All @@ -262,10 +266,6 @@ plot_compare_timeseries <- function(compare_timeseries_output,
rt_scores <- compare_timeseries_output$rt_scores
case_scores <- compare_timeseries_output$case_scores

## Fix attributes of calibration to remove warnings
names(rt_scores$calibration) <- NULL
names(case_scores$calibration) <- NULL

## Identify maximum available horizon
max_horizon <- max(rt_scores$horizon)

Expand Down Expand Up @@ -339,30 +339,25 @@ plot_compare_timeseries <- function(compare_timeseries_output,

adjust_score <- function(df, group_var) {
df_update <- df %>%
dplyr::group_by(score, .dots = group_var) %>%
dplyr::group_by(score, dplyr::pick({{ group_var }})) %>%
dplyr::mutate(upper_min = 10 * min(upper)) %>%
dplyr::ungroup() # %>%
df_update <-
df_update[which(df_update$upper <= df_update$upper_min |
df_update$score %in% c("bias", "calibration")), ] %>%
# dplyr::filter(upper <= upper_min |
# score %in% c("bias", "calibration")) %>%
df_update$score %in% "bias"), ] %>%
dplyr::ungroup() %>%
dplyr::filter(!score %in% c("logs", "dss")) %>%
dplyr::filter(!score %in% c("log_score", "dss")) %>%
dplyr::mutate(score = score %>%
factor(levels = c(
"crps", "calibration",
"sharpness", "bias",
"median", "iqr", "ci"
"crps", "bias",
"ae_median", "mad", "se_mean"
)) %>%
dplyr::recode_factor(
"crps" = "CRPS",
"calibration" = "Calibration",
"sharpness" = "Sharpness",
"bias" = "Bias",
"median" = "Median",
"iqr" = "IQR",
"ci" = "CI"
"ae_median" = "AE (median)",
"mad" = "Dispersion",
"se_mean" = "SE (mean)"
))
return(df_update)
}
Expand All @@ -379,12 +374,6 @@ summarise_scores_by_horizon <- function(scores) {
EpiSoon::summarise_scores() %>%
dplyr::mutate(horizon = "8 -- 14")

# score_14_plus <- scores %>%
# dplyr::filter(horizon > 14) %>%
# EpiSoon::summarise_scores() %>%
# dplyr::mutate(horizon = "14+")
#

scores <- score_7 %>%
dplyr::bind_rows(score_14) %>%
# dplyr::bind_rows(score_14_plus) %>%
Expand Down Expand Up @@ -430,7 +419,7 @@ plot_region_score <- function(scores, label = NULL) {
position = ggplot2::position_dodge(width = 1)
) +
ggplot2::geom_linerange(ggplot2::aes(ymin = lower, ymax = upper),
alpha = 0.4, size = 1.1,
alpha = 0.4, linewidth = 1.1,
position =
ggplot2::position_dodge(width = 1)
) +
Expand All @@ -452,11 +441,11 @@ plot_internal <- function(df, label = NULL) {
x = horizon, y = mean, col = model,
group = model
)) +
ggplot2::geom_line(size = 1.2, alpha = 0.6) +
ggplot2::geom_line(linewidth = 1.2, alpha = 0.6) +
ggplot2::geom_point(size = 2) +
ggplot2::geom_point(ggplot2::aes(y = median), shape = 2, size = 2) +
ggplot2::geom_linerange(ggplot2::aes(ymin = lower, ymax = upper),
alpha = 0.4, size = 1.5,
alpha = 0.4, linewidth = 1.5,
position =
ggplot2::position_dodge(width = 3)
) +
Expand Down Expand Up @@ -486,7 +475,7 @@ summary_plot <- function(scores, target_score) {
position = ggplot2::position_dodge(width = 1)
) +
ggplot2::geom_linerange(ggplot2::aes(ymin = lower, ymax = upper),
alpha = 0.4, size = 1.5,
alpha = 0.4, linewidth = 1.5,
position =
ggplot2::position_dodge(width = 1)
) +
Expand Down
2 changes: 1 addition & 1 deletion man/plot_compare_timeseries.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 4 additions & 2 deletions man/plot_forecast.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

38 changes: 38 additions & 0 deletions tests/testthat/test_plot_compare_timeseries.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@

test_that("plot_compare_timeseries produces expected output", {
obs_rts <- EpiSoon::example_obs_rts %>%
dplyr::mutate(timeseries = "Region 1") %>%
dplyr::bind_rows(EpiSoon::example_obs_rts %>%
dplyr::mutate(timeseries = "Region 2"))

obs_cases <- EpiSoon::example_obs_cases %>%
dplyr::mutate(timeseries = "Region 1") %>%
dplyr::bind_rows(EpiSoon::example_obs_cases %>%
dplyr::mutate(timeseries = "Region 2"))

models <- list(
"AR 3" = function(...) {
EpiSoon::bsts_model(model = function(ss, y) {
bsts::AddAr(ss, y = y, lags = 3)
}, ...)
},
"Semi-local linear trend" = function(...) {
EpiSoon::bsts_model(model = function(ss, y) {
bsts::AddSemilocalLinearTrend(ss, y = y)
}, ...)
}
)

forecast_eval <-
compare_timeseries(obs_rts, obs_cases, models,
horizon = 10, samples = 10,
serial_interval = EpiSoon::example_serial_interval
)

p <- plot_compare_timeseries(forecast_eval)
expect_length(p, 8)
purrr::walk(p, function(x) {
expect_s3_class(x, c("gg", "ggplot"))
expect_silent(plot(x))
})
})

0 comments on commit 40731b4

Please sign in to comment.