Skip to content

Commit

Permalink
WIP: enforce outcome must be a factor when classification is binary
Browse files Browse the repository at this point in the history
  • Loading branch information
kelly-sovacool committed Jun 10, 2023
1 parent e752c48 commit 3fe67a1
Show file tree
Hide file tree
Showing 12 changed files with 90 additions and 18 deletions.
58 changes: 51 additions & 7 deletions R/checks.R
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ check_outcome_column <- function(dataset, outcome_colname, check_values = TRUE,
#' \dontrun{
#' check_outcome_value(otu_small, "dx", "cancer")
#' }
check_outcome_value <- function(dataset, outcome_colname) {
check_outcome_value <- function(dataset, outcome_colname, pos_outcome = NULL) {
# check no NA's
outcomes_vec <- dataset %>% dplyr::pull(outcome_colname)
num_missing <- sum(is.na(outcomes_vec))
Expand All @@ -273,20 +273,17 @@ check_outcome_value <- function(dataset, outcome_colname) {
warning(paste0("Possible missing data in the output variable: ", num_empty, " empty value(s)."))
}

outcomes_all <- dataset %>%
dplyr::pull(outcome_colname)

# check if continuous outcome
isnum <- is.numeric(outcomes_all)
isnum <- is.numeric(outcomes_vec)
if (isnum) {
# check if it might actually be categorical
if (all(floor(outcomes_all) == outcomes_all)) {
if (all(floor(outcomes_vec) == outcomes_vec)) {
warning("Data is being considered numeric, but all outcome values are integers. If you meant to code your values as categorical, please use character values.")
}
}

# check binary and multiclass outcome
outcomes <- outcomes_all %>%
outcomes <- outcomes_vec %>%
unique()
num_outcomes <- length(outcomes)
if (num_outcomes < 2) {
Expand All @@ -299,6 +296,53 @@ check_outcome_value <- function(dataset, outcome_colname) {
}
}

#' Check or set outcome column to be a factor with `pos_class` as the first level
#'
#' @inheritParams run_ml
#'
#' @return dataset, with the outcome column as a factor
#' @keywords internal
#' @author Kelly Sovacool, \email{sovacool@@umich.edu}
#'
#' @examples
#' dat <- data.frame('dx' = c('a','b','a','b','b','a'), feat = 1:6)
#' dat %>% set_outcome_factor('dx', 'a')
#' dat %>% set_outcome_factor('dx', 'b')
set_outcome_factor <- function(dataset, outcome_colname, pos_class) {
relevel_outcome <- FALSE
outcomes_vctr <- dataset %>% dplyr::pull(outcome_colname)
# make sure it's either a factor or pos_class is set.
# the first factor level is used as the positive class by caret
if (!is.factor(outcomes_vctr)) {
if (is.null(pos_class)) {
stop(paste0("Either the outcome column `", outcome_colname,
"` must be a factor with the first factor level being the positive class,\n",
"or you must specify `pos_class`."))
}
relevel_outcome <- TRUE
} else {
first_lvl <- levels(outcomes_vctr)[1]
if (!is.null(pos_class) & pos_class != first_lvl) {
warning(paste0('`pos_class` is set, but it is not the first level in the outcome column. ',
'Releveling the outcome column to set ',
'`pos_class`=', pos_class, ' as the first level.'))
relevel_outcome <- TRUE
}
}
if (isTRUE(relevel_outcome)) {
if (!(pos_class %in% outcomes_vctr)) {
stop(paste0('pos_class `', pos_class,
'` not found in outcome column.'))
}
dataset[outcome_colname] <- factor(outcomes_vctr,
levels = unique(c(pos_class,
outcomes_vctr
))
)
}
return(dataset)
}

#' Check whether package(s) are installed
#'
#' @param ... names of packages to check
Expand Down
24 changes: 13 additions & 11 deletions R/performance.R
Original file line number Diff line number Diff line change
Expand Up @@ -116,26 +116,28 @@ get_perf_metric_name <- function(outcome_type) {
#' class_probs = TRUE
#' )
#' }
calc_perf_metrics <- function(test_data, trained_model, outcome_colname, perf_metric_function, class_probs) {
calc_perf_metrics <- function(test_data, trained_model, outcome_colname,
perf_metric_function, class_probs,
pos_class = NULL) {
pred_type <- "raw"
if (class_probs) pred_type <- "prob"
preds <- stats::predict(trained_model, test_data, type = pred_type)
obs <- test_data %>% dplyr::pull(outcome_colname)
if (class_probs) {
if (is.factor(test_data %>% dplyr::pull(outcome_colname))) {
uniq_obs <- test_data %>%
dplyr::pull(outcome_colname) %>%
levels()
if (is.factor(obs)) {
uniq_obs <- obs %>% levels()
} else {
uniq_obs <- unique(c(
test_data %>% dplyr::pull(outcome_colname),
as.character(trained_model$pred$obs)
))
uniq_obs <- unique(c(pos_class,
test_data %>% dplyr::pull(outcome_colname),
as.character(trained_model$pred$obs)
)
)
obs <- factor(test_data %>% dplyr::pull(outcome_colname), levels = uniq_obs)
}
obs <- factor(test_data %>% dplyr::pull(outcome_colname), levels = uniq_obs)
# TODO refactor this line
pred_class <- factor(names(preds)[apply(preds, 1, which.max)], levels = uniq_obs)
perf_met <- perf_metric_function(data.frame(obs = obs, pred = pred_class, preds), lev = uniq_obs)
} else {
obs <- test_data %>% dplyr::pull(outcome_colname)
perf_met <- perf_metric_function(data.frame(obs = obs, pred = preds))
}
return(perf_met)
Expand Down
11 changes: 11 additions & 0 deletions R/run_ml.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,10 @@
#' - xgbTree: xgboost
#' @param outcome_colname Column name as a string of the outcome variable
#' (default `NULL`; the first column will be chosen automatically).
#' @param pos_class The positive class, i.e. which level of `outcome_colname` is
#' the event of interest. If the outcome is binary, either the
#' `outcome_colname` must be a factor with the first level being the positive
#' class, or `pos_class` must be set. (default: `NULL`).
#' @param hyperparameters Dataframe of hyperparameters
#' (default `NULL`; sensible defaults will be chosen automatically).
#' @param seed Random seed (default: `NA`).
Expand Down Expand Up @@ -131,6 +135,7 @@ run_ml <-
function(dataset,
method,
outcome_colname = NULL,
pos_class = NULL,
hyperparameters = NULL,
find_feature_importance = FALSE,
calculate_performance = TRUE,
Expand Down Expand Up @@ -216,6 +221,10 @@ run_ml <-

outcome_type <- get_outcome_type(outcomes_vctr)
class_probs <- outcome_type != "continuous"
if (outcome_type == 'binary') {
# enforce factor levels
dataset <- dataset %>% set_outcome_factor(outcome_colname, pos_class)
}

if (is.null(perf_metric_function)) {
perf_metric_function <- get_perf_metric_fn(outcome_type)
Expand Down Expand Up @@ -254,6 +263,8 @@ run_ml <-
if (!is.na(seed)) {
set.seed(seed)
}
# verify that correct outcome level got used
trained_model_caret$levels[1]

if (calculate_performance) {
performance_tbl <- get_performance_tbl(
Expand Down
5 changes: 5 additions & 0 deletions data-raw/otu_mini_bin.R
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ otu_mini_group <- c(
otu_mini_bin_results_glmnet <- mikropml::run_ml(otu_mini_bin, # use built-in hyperparams
"glmnet",
outcome_colname = "dx",
pos_class = 'cancer',
find_feature_importance = FALSE,
seed = 2019,
cv_times = 2
Expand Down Expand Up @@ -77,6 +78,7 @@ use_data(otu_mini_cv, overwrite = TRUE)
otu_mini_bin_results_rf <- mikropml::run_ml(otu_mini_bin,
"rf",
outcome_colname = "dx",
pos_class = 'cancer',
find_feature_importance = TRUE,
seed = 2019,
cv_times = 2,
Expand All @@ -87,6 +89,7 @@ use_data(otu_mini_bin_results_rf, overwrite = TRUE)
otu_mini_bin_results_svmRadial <- mikropml::run_ml(otu_mini_bin,
"svmRadial",
outcome_colname = "dx",
pos_class = 'cancer',
find_feature_importance = FALSE,
seed = 2019,
cv_times = 2
Expand All @@ -96,6 +99,7 @@ use_data(otu_mini_bin_results_svmRadial, overwrite = TRUE)
otu_mini_bin_results_xgbTree <- mikropml::run_ml(otu_mini_bin,
"xgbTree",
outcome_colname = "dx",
pos_class = 'cancer',
find_feature_importance = FALSE,
seed = 2019,
cv_times = 2
Expand All @@ -105,6 +109,7 @@ use_data(otu_mini_bin_results_xgbTree, overwrite = TRUE)
otu_mini_bin_results_rpart2 <- mikropml::run_ml(otu_mini_bin,
"rpart2",
outcome_colname = "dx",
pos_class = 'cancer',
find_feature_importance = FALSE,
seed = 2019,
cv_times = 2
Expand Down
Binary file modified data/otu_mini_bin.rda
Binary file not shown.
Binary file modified data/otu_mini_bin_results_glmnet.rda
Binary file not shown.
Binary file modified data/otu_mini_bin_results_rf.rda
Binary file not shown.
Binary file modified data/otu_mini_bin_results_rpart2.rda
Binary file not shown.
Binary file modified data/otu_mini_bin_results_svmRadial.rda
Binary file not shown.
Binary file modified data/otu_mini_bin_results_xgbTree.rda
Binary file not shown.
Binary file modified data/otu_mini_cv.rda
Binary file not shown.
10 changes: 10 additions & 0 deletions tests/testthat/test-run_ml.R
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ test_that("run_ml works for logistic regression", {
# use built-in hyperparameters
"glmnet",
outcome_colname = "dx",
pos_class = 'cancer',
find_feature_importance = FALSE,
seed = 2019,
cv_times = 2
Expand Down Expand Up @@ -84,6 +85,7 @@ test_that("run_ml works for random forest with grouping & feature importance", {
mikropml::run_ml(otu_mini_bin,
"rf",
outcome_colname = "dx",
pos_class = 'cancer',
find_feature_importance = TRUE,
seed = 2019,
cv_times = 2,
Expand All @@ -102,6 +104,7 @@ test_that("run_ml works for svmRadial", {
mikropml::run_ml(otu_mini_bin,
"svmRadial",
outcome_colname = "dx",
pos_class = 'cancer',
find_feature_importance = FALSE,
seed = 2019,
cv_times = 2
Expand All @@ -120,6 +123,7 @@ test_that("run_ml works for xgbTree", {
otu_mini_bin,
"xgbTree",
outcome_colname = "dx",
pos_class = 'cancer',
find_feature_importance = FALSE,
seed = 2019,
cv_times = 2
Expand All @@ -137,6 +141,7 @@ test_that("run_ml works for rpart2", {
mikropml::run_ml(otu_mini_bin,
"rpart2",
outcome_colname = "dx",
pos_class = 'cancer',
find_feature_importance = FALSE,
seed = 2019,
cv_times = 2
Expand Down Expand Up @@ -210,6 +215,7 @@ test_that("run_ml uses custom training indices when provided", {
expect_warning(
results_custom_train <- run_ml(otu_mini_bin,
"glmnet",
pos_class = 'cancer',
kfold = 2,
cv_times = 5,
training_frac = training_rows,
Expand All @@ -230,6 +236,7 @@ test_that("run_ml uses custom group partitions", {
expect_message(
results_grp_part <- run_ml(otu_mini_bin,
"glmnet",
pos_class = 'cancer',
cv_times = 2,
training_frac = 0.8,
groups = grps,
Expand All @@ -255,6 +262,7 @@ test_that("run_ml catches bad training_frac values", {
run_ml(otu_mini_bin,
"glmnet",
outcome_colname = "dx",
pos_class = 'cancer',
training_frac = 0
),
"`training_frac` must be a numeric between 0 and 1."
Expand All @@ -263,6 +271,7 @@ test_that("run_ml catches bad training_frac values", {
run_ml(otu_mini_bin,
"glmnet",
outcome_colname = "dx",
pos_class = 'cancer',
training_frac = 1
),
"`training_frac` must be a numeric between 0 and 1."
Expand Down Expand Up @@ -298,6 +307,7 @@ test_that("models use case weights when provided", {
results_custom_train <- run_ml(
otu_mini_bin,
"glmnet",
pos_class = 'cancer',
kfold = 2,
cv_times = 5,
training_frac = train_weights %>% pull(row_num),
Expand Down

0 comments on commit 3fe67a1

Please sign in to comment.