Skip to content

Commit

Permalink
[R-package] avoid unnecessary computation of std deviations in `lgb.c…
Browse files Browse the repository at this point in the history
…v()` (#4360)

* [R-package] avoid unnecessary computation of std deviations in lgb.cv()

* use expect_equal()
  • Loading branch information
jameslamb authored Jun 12, 2021
1 parent 4af4698 commit f0bca1a
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 4 deletions.
11 changes: 8 additions & 3 deletions R-package/R/lgb.cv.R
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ CVBooster <- R6::R6Class(
#' @param label Vector of labels, used if \code{data} is not an \code{\link{lgb.Dataset}}
#' @param weight vector of observation weights. If not NULL, will set to dataset
#' @param record Boolean, TRUE will record iteration message to \code{booster$record_evals}
#' @param showsd \code{boolean}, whether to show standard deviation of cross validation
#' @param showsd \code{boolean}, whether to show standard deviation of cross validation.
#' This parameter defaults to \code{TRUE}. Setting it to \code{FALSE} can lead to a
#' slight speedup by avoiding unnecessary computation.
#' @param stratified a \code{boolean} indicating whether sampling of folds should be stratified
#' by the values of outcome labels.
#' @param folds \code{list} provides a possibility to use a list of pre-defined CV folds
Expand Down Expand Up @@ -379,7 +381,10 @@ lgb.cv <- function(params = list()
})

# Prepare collection of evaluation results
merged_msg <- lgb.merge.cv.result(msg = msg)
merged_msg <- lgb.merge.cv.result(
msg = msg
, showsd = showsd
)

# Write evaluation result in environment
env$eval_list <- merged_msg$eval_list
Expand Down Expand Up @@ -576,7 +581,7 @@ lgb.stratified.folds <- function(y, k) {
return(out)
}

lgb.merge.cv.result <- function(msg, showsd = TRUE) {
lgb.merge.cv.result <- function(msg, showsd) {

# Get CV message length
if (length(msg) == 0L) {
Expand Down
4 changes: 3 additions & 1 deletion R-package/man/lgb.cv.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

33 changes: 33 additions & 0 deletions R-package/tests/testthat/test_basic.R
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,39 @@ test_that("lgb.cv() fit on linearly-relatead data improves when using linear lea
expect_true(cv_bst_linear$best_score < cv_bst$best_score)
})

# Checks that lgb.cv() produces the same recorded "eval" values whether or not
# showsd is requested, and that "eval_err" is only populated when
# showsd = TRUE.
test_that("lgb.cv() respects showsd argument", {
    dtrain <- lgb.Dataset(train$data, label = train$label)
    params <- list(objective = "regression", metric = "l2")
    nrounds <- 5L

    # Runs cross-validation with the given showsd value and returns the
    # recorded evaluation results for the "l2" metric on the "valid" set.
    # The same seed is set before every run so that the two runs are
    # comparable (presumably this fixes the random fold assignment).
    run_cv_evals <- function(showsd) {
        set.seed(708L)
        cv_bst <- lgb.cv(
            params = params
            , data = dtrain
            , nrounds = nrounds
            , nfold = 3L
            , min_data = 1L
            , showsd = showsd
        )
        return(cv_bst$record_evals[["valid"]][["l2"]])
    }

    evals_showsd <- run_cv_evals(showsd = TRUE)
    evals_no_showsd <- run_cv_evals(showsd = FALSE)

    # the recorded metric values must not depend on showsd
    expect_equal(
        evals_showsd[["eval"]]
        , evals_no_showsd[["eval"]]
    )

    # with showsd = TRUE there is one "eval_err" entry per boosting round
    expect_true(is.list(evals_showsd[["eval_err"]]))
    expect_length(evals_showsd[["eval_err"]], nrounds)

    # with showsd = FALSE, "eval_err" is left as an empty list
    expect_identical(evals_no_showsd[["eval_err"]], list())
})

context("lgb.train()")

test_that("lgb.train() works as expected with multiple eval metrics", {
Expand Down

0 comments on commit f0bca1a

Please sign in to comment.