[R-package] Use type argument to control prediction types #5133
Changes from 3 commits
@@ -713,6 +713,23 @@ Booster <- R6::R6Class(
 #' @param object Object of class \code{lgb.Booster}
 #' @param newdata a \code{matrix} object, a \code{dgCMatrix} object or
 #' a character representing a path to a text file (CSV, TSV, or LibSVM)
+#' @param type Type of prediction to output. Allowed types are:\itemize{
+#' \item \code{"link"}: will output the predicted score according to the objective function being
+#' optimized (depending on the link function that the objective uses), after applying any necessary
+#' transformations - for example, for \code{objective="binary"}, it will output class probabilities.
+#' \item \code{"response"}: for classification objectives, will output the class with the highest predicted
+#' probability. For other objectives, will output the same as "link".
+#' \item \code{"raw"}: will output the non-transformed numbers (sum of predictions from boosting iterations'
+#' results) from which the "link" number is produced for a given objective function - for example, for
+#' \code{objective="binary"}, this corresponds to log-odds. For many objectives such as "regression",
+#' since no transformation is applied, the output will be the same as for "link".
+#' \item \code{"leaf"}: will output the index of the terminal node / leaf at which each observation falls
+#' in each tree in the model, outputted as integers, with one column per tree.

Suggested change

+#' \item \code{"contrib"}: will return the per-feature contributions for each prediction, including an
+#' intercept (each feature will produce one column). If there are multiple classes, each class will
+#' have separate feature contributions (thus the number of columns is features+1 multiplied by the

Suggested change

+#' number of classes).
+#' }
 #' @param start_iteration int or None, optional (default=None)
 #' Start index of the iteration to predict.
 #' If None or <= 0, starts from the first iteration.
@@ -721,22 +738,19 @@ Booster <- R6::R6Class(
 #' If None, if the best iteration exists and start_iteration is None or <= 0, the
 #' best iteration is used; otherwise, all iterations from start_iteration are used.
 #' If <= 0, all iterations from start_iteration are used (no limits).
-#' @param rawscore whether the prediction should be returned in the for of original untransformed
-#' sum of predictions from boosting iterations' results. E.g., setting \code{rawscore=TRUE}
-#' for logistic regression would result in predictions for log-odds instead of probabilities.
-#' @param predleaf whether predict leaf index instead.
-#' @param predcontrib return per-feature contributions for each record.
 #' @param header only used for prediction for text file. True if text file has header
 #' @param params a list of additional named parameters. See
 #' \href{https://lightgbm.readthedocs.io/en/latest/Parameters.html#predict-parameters}{
 #' the "Predict Parameters" section of the documentation} for a list of parameters and
 #' valid values.
 #' @param ... ignored
-#' @return For regression or binary classification, it returns a vector of length \code{nrows(data)}.
-#' For multiclass classification, it returns a matrix of dimensions \code{(nrows(data), num_class)}.
+#' @return For prediction types that are meant to always return one output per observation (e.g. when predicting
+#' \code{type="link"} on a binary classification or regression objective), will return a vector with one
+#' row per observation in \code{newdata}.
 #'
-#' When passing \code{predleaf=TRUE} or \code{predcontrib=TRUE}, the output will always be
-#' returned as a matrix.
+#' For prediction types that are meant to return more than one output per observation (e.g. when predicting
+#' \code{type="link"} on a multi-class objective, or when predicting \code{type="leaf"}, regardless of
+#' objective), will return a matrix with one row per observation in \code{newdata} and one column per output.
 #'
 #' @examples
 #' \donttest{
@@ -770,15 +784,13 @@ Booster <- R6::R6Class(
 #' )
 #' )
 #' }
-#' @importFrom utils modifyList
+#' @importFrom utils modifyList head
 #' @export
 predict.lgb.Booster <- function(object,
                                 newdata,
+                                type = c("link", "response", "raw", "leaf", "contrib"),
                                 start_iteration = NULL,
                                 num_iteration = NULL,
-                                rawscore = FALSE,
-                                predleaf = FALSE,

Comment on lines -809 to -810

Down in the section where this function catches arguments that fall into
LightGBM/R-package/R/lgb.Booster.R
Lines 822 to 824 in f715645

Please add the following:

    if (isTRUE(additional_params[["rawscore"]])) {
        stop("Argument 'rawscore' is no longer supported. Use type = 'raw' instead.")
    }
    if (isTRUE(additional_params[["predleaf"]])) {
        stop("Argument 'predleaf' is no longer supported. Use type = 'leaf' instead.")
    }
    if (isTRUE(additional_params[["predcontrib"]])) {
        stop("Argument 'predcontrib' is no longer supported. Use type = 'contrib' instead.")
    }

I'm ok with breaking users' code in the next release in exchange for making the package's interface more compatible with other packages for modeling in R, but I think we should provide specific, actionable error messages when possible to reduce the effort required for affected users to alter their code.
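
The three checks requested above share one pattern, so they could also be driven by a small lookup table. The following is only a minimal sketch in base R: check_removed_args() is a hypothetical helper that is not part of the PR or the package, and it assumes the removed arguments arrive through the arguments captured by the function's dots.

```r
# Hypothetical helper (not part of the PR): stops with an actionable message
# when a removed argument is passed as TRUE.
check_removed_args <- function(...) {
  additional_params <- list(...)
  # Maps each removed argument to the 'type' value that replaces it.
  replacements <- c(rawscore = "raw", predleaf = "leaf", predcontrib = "contrib")
  for (arg in names(replacements)) {
    if (isTRUE(additional_params[[arg]])) {
      stop(sprintf(
        "Argument '%s' is no longer supported. Use type = '%s' instead.",
        arg, replacements[[arg]]
      ))
    }
  }
  invisible(NULL)
}

# Passing a removed argument as TRUE signals an error; anything else passes.
check_removed_args(header = FALSE)
try(check_removed_args(predleaf = TRUE), silent = TRUE)
```

As in the reviewer's snippet, a value of FALSE for a removed argument passes silently; whether that case should also fail is a design decision this sketch does not settle.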

-                                predcontrib = FALSE,
                                 header = FALSE,
                                 params = list(),
                                 ...) {
@@ -799,18 +811,36 @@ predict.lgb.Booster <- function(object,
     ))
   }

-  return(
-    object$predict(
-      data = newdata
-      , start_iteration = start_iteration
-      , num_iteration = num_iteration
-      , rawscore = rawscore
-      , predleaf = predleaf
-      , predcontrib = predcontrib
-      , header = header
-      , params = params
-    )
+  type <- head(type, 1L)

What is the purpose of having

The purpose is to have the allowed values visible in the function signature so that they are easily seen by the user and easy to autocomplete, in the same way as for example base R's

Ok thank you. This project has documentation for the purpose of describing which values are supported, and I believe the pattern of having a default value which is not the value that will be directly used will be confusing for both developers and users of the package. Please remove this and set the default to
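
To make the trade-off in this thread concrete, here is a minimal sketch in base R. Both functions are hypothetical and not part of the PR. The first mirrors the diff's head(type, 1L) line: it falls back to the first allowed value when the caller passes nothing, but does not validate what the caller passes. The second uses base R's match.arg(), shown here only as a standard alternative that also validates the value.

```r
# Hypothetical function mirroring the PR's `type <- head(type, 1L)` line.
pick_type_head <- function(type = c("link", "response", "raw", "leaf", "contrib")) {
  head(type, 1L)
}

# Alternative using match.arg(), which rejects values outside the default vector.
pick_type_checked <- function(type = c("link", "response", "raw", "leaf", "contrib")) {
  match.arg(type)
}

pick_type_head()          # the first allowed value is used when nothing is passed
pick_type_head("bogus")   # returned unchanged: no validation happens
pick_type_checked()       # also falls back to the first allowed value
try(pick_type_checked("bogus"), silent = TRUE)  # signals an error
```

A single-string default, as the last comment requests, avoids the vector default entirely; in that case validation would need to be done explicitly.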

+  rawscore <- FALSE
+  predleaf <- FALSE
+  predcontrib <- FALSE
+  if (type == "raw") {
+    rawscore <- TRUE
+  } else if (type == "leaf") {
+    predleaf <- TRUE
+  } else if (type == "contrib") {
+    predcontrib <- TRUE
+  }
+
+  pred <- object$predict(
+    data = newdata
+    , start_iteration = start_iteration
+    , num_iteration = num_iteration
+    , rawscore = rawscore
+    , predleaf = predleaf
+    , predcontrib = predcontrib
+    , header = header
+    , params = params
   )
+  if (type == "response") {
+    if (object$params$objective == "binary") {
+      pred <- as.integer(pred >= 0.5)
+    } else if (object$params$objective %in% c("multiclass", "multiclassova")) {
+      pred <- max.col(pred) - 1L
+    }
+  }

Since this is new behavior being added to the package, please add unit tests confirming that it works as expected.
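
As a rough idea of what such tests could check, the post-processing for type = "response" can be exercised on hand-made predictions without training a model. This is a minimal sketch only: to_response() is a hypothetical stand-alone copy of the branch in the diff, and ties.method = "first" is an assumption added here to keep the result deterministic (max.col() breaks ties at random by default).

```r
# Hypothetical stand-alone version of the type = "response" branch in the diff.
to_response <- function(pred, objective) {
  if (objective == "binary") {
    # Probabilities at or above 0.5 map to class 1.
    return(as.integer(pred >= 0.5))
  }
  if (objective %in% c("multiclass", "multiclassova")) {
    # Zero-based index of the most probable class in each row.
    return(max.col(pred, ties.method = "first") - 1L)
  }
  # Other objectives are returned unchanged.
  pred
}

binary_pred <- c(0.2, 0.5, 0.9)
multi_pred <- matrix(c(0.1, 0.7, 0.2,
                       0.6, 0.3, 0.1), nrow = 2L, byrow = TRUE)

stopifnot(identical(to_response(binary_pred, "binary"), c(0L, 1L, 1L)))
stopifnot(identical(to_response(multi_pred, "multiclass"), c(1L, 0L)))
stopifnot(identical(to_response(binary_pred, "regression"), binary_pred))
```

Tests in the package itself would additionally need to call predict() on a trained model for each value of type, which this sketch does not cover.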

+  return(pred)
 }

 #' @name print.lgb.Booster

Can you please add a note to this documentation that when choosing "link" and "response", if you're using a custom objective function they'll be ignored and "raw" predictions will be returned?

On the Python side, lightgbm raises a warning in such situations.

LightGBM/python-package/lightgbm/sklearn.py
Lines 1064 to 1067 in f715645

The R package should probably do that too, but that could be deferred to a later PR.
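
One possible shape for such a warning on the R side is sketched below. Everything here is an assumption made for illustration: resolve_prediction_type() is a hypothetical helper, the warning text is invented, and a custom objective is assumed to be represented as an R function rather than a string.

```r
# Hypothetical helper: falls back to raw scores, with a warning, when a custom
# objective makes "link" or "response" impossible to compute.
resolve_prediction_type <- function(type, objective) {
  # Assumption: built-in objectives are strings, custom objectives are functions.
  if (is.function(objective) && type %in% c("link", "response")) {
    warning(sprintf(
      "type = '%s' is ignored with a custom objective function; returning raw scores.",
      type
    ))
    return("raw")
  }
  type
}

custom_objective <- function(preds, dtrain) NULL  # placeholder custom objective

resolve_prediction_type("response", "binary")                # unchanged
suppressWarnings(resolve_prediction_type("link", custom_objective))  # falls back to "raw"
```

Types that do not depend on the objective's transformation, such as "leaf" or "contrib", pass through unchanged in this sketch.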