Skip to content

Commit

Permalink
Tidying code diablo_plot_tune
Browse files Browse the repository at this point in the history
  • Loading branch information
oliviaAB committed Mar 12, 2024
1 parent 99e92cb commit 95fe9e3
Showing 1 changed file with 31 additions and 11 deletions.
42 changes: 31 additions & 11 deletions R/diablo.R
Original file line number Diff line number Diff line change
Expand Up @@ -432,11 +432,11 @@ diablo_tune <- function(mixomics_data, design_matrix, keepX_list = NULL, cpus =

#' Plots DIABLO tune results
#'
#' Displays the error rate of a DIABLO run cross-validation to estimate the optimal number of features to retain from
#' each dataset (\code{keepX}).
#' Displays the error rate of a DIABLO run cross-validation to estimate the
#' optimal number of features to retain from each dataset (\code{keepX}).
#'
#' @param tune_res The cross-validation results, computed with \code{\link{diablo_tune}}.
#' @return A \code{ggplot2} object.
#' @param tune_res The cross-validation results, computed with [diablo_tune()].
#' @return A `ggplot2` object.
#' @export
diablo_plot_tune <- function(tune_res) {
## For devtools::check()
Expand All @@ -449,12 +449,17 @@ diablo_plot_tune <- function(tune_res) {
error_label <- error_label[tune_res$measure]
datasets_name <- names(tune_res$choice.keepX)
comps <- colnames(tune_res$error.rate)
comps_colours <- RColorBrewer::brewer.pal(max(length(comps), 3), "Set2")[1:length(comps)]
comps_colours <- RColorBrewer::brewer.pal(max(length(comps), 3), "Set2")
comps_colours <- comps_colours[1:length(comps)]
names(comps_colours) <- comps

## Reading error sd
if (is.null(rownames(tune_res$error.rate.sd))) rownames(tune_res$error.rate.sd) <- rownames(tune_res$error.rate)
if (is.null(colnames(tune_res$error.rate.sd))) colnames(tune_res$error.rate.sd) <- colnames(tune_res$error.rate)
if (is.null(rownames(tune_res$error.rate.sd))) {
rownames(tune_res$error.rate.sd) <- rownames(tune_res$error.rate)
}
if (is.null(colnames(tune_res$error.rate.sd))) {
colnames(tune_res$error.rate.sd) <- colnames(tune_res$error.rate)
}

df_sd <- tibble::as_tibble(tune_res$error.rate.sd, rownames = "id") |>
tidyr::pivot_longer(
Expand All @@ -471,7 +476,13 @@ diablo_plot_tune <- function(tune_res) {
values_to = "error"
) |>
dplyr::left_join(df_sd, by = c("id", "comp")) |>
tidyr::separate(id, into = datasets_name, sep = "_", remove = FALSE, convert = TRUE) |>
tidyr::separate(
id,
into = datasets_name,
sep = "_",
remove = FALSE,
convert = TRUE
) |>
dplyr::mutate(
error_min = error - sd,
error_max = error + sd
Expand All @@ -498,7 +509,10 @@ diablo_plot_tune <- function(tune_res) {
ggplot2::ggplot(aes(x = dataset, y = id, fill = nb_retained_features)) +
ggplot2::geom_tile() +
ggplot2::scale_x_discrete(expand = c(0, 0)) +
ggplot2::scale_fill_viridis_c(option = "plasma", guide = ifelse(i == comps[1], "colourbar", "none")) +
ggplot2::scale_fill_viridis_c(
option = "plasma",
guide = ifelse(i == comps[1], "colourbar", "none")
) +
ggplot2::theme(
axis.text.y = ggplot2::element_blank(),
axis.ticks.y = ggplot2::element_blank(),
Expand All @@ -517,7 +531,9 @@ diablo_plot_tune <- function(tune_res) {
ggplot2::ggplot(aes(x = id)) +
ggplot2::geom_col(aes(y = error), fill = comps_colours[i], width = 1) +
ggplot2::coord_flip() +
ggplot2::scale_y_continuous(expand = ggplot2::expansion(mult = c(0, 0.05))) +
ggplot2::scale_y_continuous(
expand = ggplot2::expansion(mult = c(0, 0.05))
) +
ggplot2::theme_bw() +
ggplot2::theme(
axis.text.y = ggplot2::element_blank(),
Expand Down Expand Up @@ -547,7 +563,11 @@ diablo_plot_tune <- function(tune_res) {
if (i == comps[1]) common_legend <- ggpubr::get_legend(plot1)
}

ggpubr::ggarrange(plotlist = plots_list, legend.grob = common_legend, legend = "bottom")
ggpubr::ggarrange(
plotlist = plots_list,
legend.grob = common_legend,
legend = "bottom"
)
}

#' Formatted table with optimal keepX values
Expand Down

0 comments on commit 95fe9e3

Please sign in to comment.