Skip to content

Commit

Permalink
[R-package] Remove reshape argument in predict (#4971)
Browse files Browse the repository at this point in the history
* change prediction default to reshape=TRUE

* remove reshape argument

* comments
  • Loading branch information
david-cortes authored Apr 1, 2022
1 parent 33eb037 commit 248fbfa
Show file tree
Hide file tree
Showing 7 changed files with 73 additions and 45 deletions.
17 changes: 6 additions & 11 deletions R-package/R/lgb.Booster.R
Original file line number Diff line number Diff line change
Expand Up @@ -474,7 +474,6 @@ Booster <- R6::R6Class(
predleaf = FALSE,
predcontrib = FALSE,
header = FALSE,
reshape = FALSE,
params = list()) {

self$restore_handle()
Expand All @@ -501,7 +500,6 @@ Booster <- R6::R6Class(
, predleaf = predleaf
, predcontrib = predcontrib
, header = header
, reshape = reshape
)
)

Expand Down Expand Up @@ -729,20 +727,16 @@ Booster <- R6::R6Class(
#' @param predleaf whether predict leaf index instead.
#' @param predcontrib return per-feature contributions for each record.
#' @param header only used for prediction for text file. True if text file has header
#' @param reshape whether to reshape the vector of predictions to a matrix form when there are several
#' prediction outputs per case.
#' @param params a list of additional named parameters. See
#' \href{https://lightgbm.readthedocs.io/en/latest/Parameters.html#predict-parameters}{
#' the "Predict Parameters" section of the documentation} for a list of parameters and
#' valid values.
#' @param ... ignored
#' @return For regression or binary classification, it returns a vector of length \code{nrows(data)}.
#' For multiclass classification, either a \code{num_class * nrows(data)} vector or
#' a \code{(nrows(data), num_class)} dimension matrix is returned, depending on
#' the \code{reshape} value.
#' For multiclass classification, it returns a matrix of dimensions \code{(nrows(data), num_class)}.
#'
#' When \code{predleaf = TRUE}, the output is a matrix object with the
#' number of columns corresponding to the number of trees.
#' When passing \code{predleaf=TRUE} or \code{predcontrib=TRUE}, the output will always be
#' returned as a matrix.
#'
#' @examples
#' \donttest{
Expand Down Expand Up @@ -786,7 +780,6 @@ predict.lgb.Booster <- function(object,
predleaf = FALSE,
predcontrib = FALSE,
header = FALSE,
reshape = FALSE,
params = list(),
...) {

Expand All @@ -796,6 +789,9 @@ predict.lgb.Booster <- function(object,

additional_params <- list(...)
if (length(additional_params) > 0L) {
if ("reshape" %in% names(additional_params)) {
stop("'reshape' argument is no longer supported.")
}
warning(paste0(
"predict.lgb.Booster: Found the following passed through '...': "
, paste(names(additional_params), collapse = ", ")
Expand All @@ -812,7 +808,6 @@ predict.lgb.Booster <- function(object,
, predleaf = predleaf
, predcontrib = predcontrib
, header = header
, reshape = reshape
, params = params
)
)
Expand Down
1 change: 0 additions & 1 deletion R-package/R/lgb.Dataset.R
Original file line number Diff line number Diff line change
Expand Up @@ -295,7 +295,6 @@ Dataset <- R6::R6Class(
init_score <- private$predictor$predict(
data = private$raw_data
, rawscore = TRUE
, reshape = TRUE
)

# Not needed to transpose, for is col_marjor
Expand Down
16 changes: 2 additions & 14 deletions R-package/R/lgb.Predictor.R
Original file line number Diff line number Diff line change
Expand Up @@ -84,8 +84,7 @@ Predictor <- R6::R6Class(
rawscore = FALSE,
predleaf = FALSE,
predcontrib = FALSE,
header = FALSE,
reshape = FALSE) {
header = FALSE) {

# Check if number of iterations is existing - if not, then set it to -1 (use all)
if (is.null(num_iteration)) {
Expand Down Expand Up @@ -215,23 +214,12 @@ Predictor <- R6::R6Class(
# Get number of cases per row
npred_per_case <- length(preds) / num_row


# Data reshaping

if (predleaf | predcontrib) {

# Predict leaves only, reshaping is mandatory
preds <- matrix(preds, ncol = npred_per_case, byrow = TRUE)

} else if (reshape && npred_per_case > 1L) {

# Predict with data reshaping
if (npred_per_case > 1L || predleaf || predcontrib) {
preds <- matrix(preds, ncol = npred_per_case, byrow = TRUE)

}

return(preds)

}

),
Expand Down
10 changes: 2 additions & 8 deletions R-package/demo/multiclass.R
Original file line number Diff line number Diff line change
Expand Up @@ -56,21 +56,15 @@ model <- lgb.train(
# We can predict on test data, identical
my_preds <- predict(model, test[, 1L:4L])

# A (30x3) matrix with the predictions, use parameter reshape
# A (30x3) matrix with the predictions
# class1 class2 class3
# obs1 obs1 obs1
# obs2 obs2 obs2
# .... .... ....
my_preds <- predict(model, test[, 1L:4L], reshape = TRUE)
my_preds <- predict(model, test[, 1L:4L])

# We can also get the predicted scores before the Sigmoid/Softmax application
my_preds <- predict(model, test[, 1L:4L], rawscore = TRUE)

# Raw score predictions as matrix instead of vector
my_preds <- predict(model, test[, 1L:4L], rawscore = TRUE, reshape = TRUE)

# We can also get the leaf index
my_preds <- predict(model, test[, 1L:4L], predleaf = TRUE)

# Predict leaf index as matrix instead of vector
my_preds <- predict(model, test[, 1L:4L], predleaf = TRUE, reshape = TRUE)
4 changes: 2 additions & 2 deletions R-package/demo/multiclass_custom_objective.R
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ model_builtin <- lgb.train(
, obj = "multiclass"
)

preds_builtin <- predict(model_builtin, test[, 1L:4L], rawscore = TRUE, reshape = TRUE)
preds_builtin <- predict(model_builtin, test[, 1L:4L], rawscore = TRUE)
probs_builtin <- exp(preds_builtin) / rowSums(exp(preds_builtin))

# Method 2 of training with custom objective function
Expand Down Expand Up @@ -109,7 +109,7 @@ model_custom <- lgb.train(
, eval = custom_multiclass_metric
)

preds_custom <- predict(model_custom, test[, 1L:4L], rawscore = TRUE, reshape = TRUE)
preds_custom <- predict(model_custom, test[, 1L:4L], rawscore = TRUE)
probs_custom <- exp(preds_custom) / rowSums(exp(preds_custom))

# compare predictions
Expand Down
12 changes: 3 additions & 9 deletions R-package/man/predict.lgb.Booster.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

58 changes: 58 additions & 0 deletions R-package/tests/testthat/test_Predictor.R
Original file line number Diff line number Diff line change
Expand Up @@ -111,3 +111,61 @@ test_that("start_iteration works correctly", {
pred_leaf2 <- predict(bst, test$data, start_iteration = 0L, num_iteration = end_iter + 1L, predleaf = TRUE)
expect_equal(pred_leaf1, pred_leaf2)
})

test_that("predictions for regression and binary classification are returned as vectors", {
data(mtcars)
X <- as.matrix(mtcars[, -1L])
y <- as.numeric(mtcars[, 1L])
dtrain <- lgb.Dataset(X, label = y, params = list(max_bins = 5L))
model <- lgb.train(
data = dtrain
, obj = "regression"
, nrounds = 5L
, verbose = VERBOSITY
)
pred <- predict(model, X)
expect_true(is.vector(pred))
expect_equal(length(pred), nrow(X))
pred <- predict(model, X, rawscore = TRUE)
expect_true(is.vector(pred))
expect_equal(length(pred), nrow(X))

data(agaricus.train, package = "lightgbm")
X <- agaricus.train$data
y <- agaricus.train$label
dtrain <- lgb.Dataset(X, label = y)
model <- lgb.train(
data = dtrain
, obj = "binary"
, nrounds = 5L
, verbose = VERBOSITY
)
pred <- predict(model, X)
expect_true(is.vector(pred))
expect_equal(length(pred), nrow(X))
pred <- predict(model, X, rawscore = TRUE)
expect_true(is.vector(pred))
expect_equal(length(pred), nrow(X))
})

test_that("predictions for multiclass classification are returned as matrix", {
data(iris)
X <- as.matrix(iris[, -5L])
y <- as.numeric(iris$Species) - 1.0
dtrain <- lgb.Dataset(X, label = y)
model <- lgb.train(
data = dtrain
, obj = "multiclass"
, nrounds = 5L
, verbose = VERBOSITY
, params = list(num_class = 3L)
)
pred <- predict(model, X)
expect_true(is.matrix(pred))
expect_equal(nrow(pred), nrow(X))
expect_equal(ncol(pred), 3L)
pred <- predict(model, X, rawscore = TRUE)
expect_true(is.matrix(pred))
expect_equal(nrow(pred), nrow(X))
expect_equal(ncol(pred), 3L)
})

0 comments on commit 248fbfa

Please sign in to comment.