Skip to content

Commit

Permalink
[R-package] Accept factor labels and use their levels (#5341)
Browse files Browse the repository at this point in the history
  • Loading branch information
david-cortes authored Feb 14, 2023
1 parent 9713ff4 commit c676a7e
Show file tree
Hide file tree
Showing 9 changed files with 313 additions and 10 deletions.
2 changes: 1 addition & 1 deletion R-package/DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -63,4 +63,4 @@ Imports:
utils
SystemRequirements:
C++11
RoxygenNote: 7.2.1
RoxygenNote: 7.2.3
24 changes: 24 additions & 0 deletions R-package/R/aliases.R
Original file line number Diff line number Diff line change
Expand Up @@ -78,3 +78,27 @@
)
)
}

.MULTICLASS_OBJECTIVES <- function() {
return(
c(
"multi_logloss"
, "multiclass"
, "softmax"
, "multiclassova"
, "multiclass_ova"
, "ova"
, "ovr"
)
)
}

.BINARY_OBJECTIVES <- function() {
return(
c(
"binary_logloss"
, "binary"
, "binary_error"
)
)
}
21 changes: 19 additions & 2 deletions R-package/R/lgb.Booster.R
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Booster <- R6::R6Class(
best_score = NA_real_,
params = list(),
record_evals = list(),
data_processor = NULL,

# Finalize will free up the handles
finalize = function() {
Expand Down Expand Up @@ -837,6 +838,11 @@ Booster <- R6::R6Class(
#'
#' Note that, if using custom objectives, types "class" and "response" will not be available and will
#' default towards using "raw" instead.
#'
#' If the model was fit through function \link{lightgbm} and it was passed a factor as labels,
#' passing the prediction type through \code{params} instead of through this argument might
#' result in factor levels for classification objectives not being applied correctly to the
#' resulting output.
#' @param start_iteration int or None, optional (default=None)
#' Start index of the iteration to predict.
#' If None or <= 0, starts from the first iteration.
Expand Down Expand Up @@ -895,6 +901,11 @@ NULL
#' in the order "feature contributions for first class, feature contributions for second class, feature
#' contributions for third class, etc.".
#'
#' If the model was fit through function \link{lightgbm} and it was passed a factor as labels, predictions
#' returned from this function will retain the factor levels (either as values for \code{type="class"}, or
#' as column names for \code{type="response"} and \code{type="raw"} for multi-class objectives). Note that
#' passing the requested prediction type under \code{params} instead of through \code{type} might result in
#' the factor levels not being present in the output.
#' @examples
#' \donttest{
#' data(agaricus.train, package = "lightgbm")
Expand Down Expand Up @@ -996,12 +1007,18 @@ predict.lgb.Booster <- function(object,
, params = params
)
if (type == "class") {
if (object$params$objective == "binary") {
if (object$params$objective %in% .BINARY_OBJECTIVES()) {
pred <- as.integer(pred >= 0.5)
} else if (object$params$objective %in% c("multiclass", "multiclassova")) {
} else if (object$params$objective %in% .MULTICLASS_OBJECTIVES()) {
pred <- max.col(pred) - 1L
}
}
if (!is.null(object$data_processor)) {
pred <- object$data_processor$process_predictions(
pred = pred
, type = type
)
}
return(pred)
}

Expand Down
94 changes: 94 additions & 0 deletions R-package/R/lgb.DataProcessor.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
DataProcessor <- R6::R6Class(
classname = "lgb.DataProcessor",
public = list(
factor_levels = NULL,

process_label = function(label, objective, params) {

if (is.character(label)) {
label <- factor(label)
}

if (is.factor(label)) {

self$factor_levels <- levels(label)
if (length(self$factor_levels) <= 1L) {
stop("Labels to predict is a factor with <2 possible values.")
}

label <- as.numeric(label) - 1.0
out <- list(label = label)
if (length(self$factor_levels) == 2L) {
if (objective == "auto") {
objective <- "binary"
}
if (!(objective %in% .BINARY_OBJECTIVES())) {
stop("Two-level factors as labels only allowed for objective='binary' or objective='auto'.")
}
} else {
if (objective == "auto") {
objective <- "multiclass"
}
if (!(objective %in% .MULTICLASS_OBJECTIVES())) {
stop(
sprintf(
"Factors with >2 levels as labels only allowed for multi-class objectives. Got: %s (allowed: %s)"
, objective
, toString(.MULTICLASS_OBJECTIVES())
)
)
}
data_num_class <- length(self$factor_levels)
params <- lgb.check.wrapper_param(
main_param_name = "num_class"
, params = params
, alternative_kwarg_value = data_num_class
)
if (params[["num_class"]] != data_num_class) {
warning(
sprintf(
"Found num_class=%d in params, but 'label' is a factor with %d levels. 'num_class' will be ignored."
, params[["num_class"]]
, data_num_class
)
)
params$num_class <- data_num_class
}
}
out$objective <- objective
out$params <- params
return(out)

} else {

label <- as.numeric(label)
if (objective == "auto") {
objective <- "regression"
}
out <- list(
label = label
, objective = objective
, params = params
)
return(out)

}
},

process_predictions = function(pred, type) {
if (NROW(self$factor_levels)) {
if (type == "class") {
pred <- as.integer(pred) + 1L
attributes(pred)$levels <- self$factor_levels
attributes(pred)$class <- "factor"
} else if (type %in% c("response", "raw")) {
if (is.matrix(pred) && ncol(pred) == length(self$factor_levels)) {
colnames(pred) <- self$factor_levels
}
}
}

return(pred)
}
)
)
28 changes: 27 additions & 1 deletion R-package/R/lightgbm.R
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,15 @@ NULL
#' For a list of accepted objectives, see
#' \href{https://lightgbm.readthedocs.io/en/latest/Parameters.html#objective}{
#' the "objective" item of the "Parameters" section of the documentation}.
#'
#' If passing \code{"auto"} and \code{data} is not of type \code{lgb.Dataset}, the objective will
#' be determined according to what is passed for \code{label}:\itemize{
#' \item If passing a factor with two variables, will use objective \code{"binary"}.
#' \item If passing a factor with more than two variables, will use objective \code{"multiclass"}
#' (note that parameter \code{num_class} in this case will also be determined automatically from
#' \code{label}).
#' \item Otherwise, will use objective \code{"regression"}.
#' }
#' @param init_score initial score is the base prediction lightgbm will boost from
#' @param num_threads Number of parallel threads to use. For best speed, this should be set to the number of
#' physical cores in the CPU - in a typical x86-64 machine, this corresponds to half the
Expand Down Expand Up @@ -149,7 +158,7 @@ lightgbm <- function(data,
init_model = NULL,
callbacks = list(),
serializable = TRUE,
objective = "regression",
objective = "auto",
init_score = NULL,
num_threads = NULL,
...) {
Expand All @@ -173,6 +182,22 @@ lightgbm <- function(data,
, alternative_kwarg_value = verbose
)

# Process factors as labels and auto-determine objective
if (!lgb.is.Dataset(data)) {
data_processor <- DataProcessor$new()
temp <- data_processor$process_label(
label = label
, objective = objective
, params = params
)
label <- temp$label
objective <- temp$objective
params <- temp$params
rm(temp)
} else {
data_processor <- NULL
}

# Set data to a temporary variable
dtrain <- data

Expand Down Expand Up @@ -204,6 +229,7 @@ lightgbm <- function(data,
what = lgb.train
, args = train_args
)
bst$data_processor <- data_processor

return(bst)
}
Expand Down
7 changes: 6 additions & 1 deletion R-package/man/lgb.configure_fast_predict.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 13 additions & 4 deletions R-package/man/lightgbm.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 12 additions & 1 deletion R-package/man/predict.lgb.Booster.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit c676a7e

Please sign in to comment.