Skip to content

Commit

Permalink
[R-package] respect aliases for objective and metric and lgb.train() …
Browse files Browse the repository at this point in the history
…and lgb.cv() (#4913)

* [R-package] respect aliases for objective and metric

* move eval code closer to eval processing

* remove unnecessary diff

* Update R-package/tests/testthat/test_basic.R
  • Loading branch information
jameslamb authored Jan 5, 2022
1 parent af5b40e commit ac821f0
Show file tree
Hide file tree
Showing 3 changed files with 123 additions and 14 deletions.
22 changes: 15 additions & 7 deletions R-package/R/lgb.cv.R
Original file line number Diff line number Diff line change
Expand Up @@ -105,12 +105,6 @@ lgb.cv <- function(params = list()
data <- lgb.Dataset(data = data, label = label)
}

# Setup temporary variables
params <- lgb.check.obj(params = params, obj = obj)
params <- lgb.check.eval(params = params, eval = eval)
fobj <- NULL
eval_functions <- list(NULL)

# set some parameters, resolving the way they were passed in with other parameters
# in `params`.
# this ensures that the model stored with Booster$save() correctly represents
Expand All @@ -125,14 +119,26 @@ lgb.cv <- function(params = list()
, params = params
, alternative_kwarg_value = nrounds
)
params <- lgb.check.wrapper_param(
main_param_name = "metric"
, params = params
, alternative_kwarg_value = NULL
)
params <- lgb.check.wrapper_param(
main_param_name = "objective"
, params = params
, alternative_kwarg_value = NULL
)
params <- lgb.check.wrapper_param(
main_param_name = "early_stopping_round"
, params = params
, alternative_kwarg_value = early_stopping_rounds
)
early_stopping_rounds <- params[["early_stopping_round"]]

# Check for objective (function or not)
# extract any function objects passed for objective or metric
params <- lgb.check.obj(params = params, obj = obj)
fobj <- NULL
if (is.function(params$objective)) {
fobj <- params$objective
params$objective <- "NONE"
Expand All @@ -142,6 +148,8 @@ lgb.cv <- function(params = list()
# (for backwards compatibility). If it is a list of functions, store
# all of them. This makes it possible to pass any mix of strings like "auc"
# and custom functions to eval
params <- lgb.check.eval(params = params, eval = eval)
eval_functions <- list(NULL)
if (is.function(eval)) {
eval_functions <- list(eval)
}
Expand Down
22 changes: 15 additions & 7 deletions R-package/R/lgb.train.R
Original file line number Diff line number Diff line change
Expand Up @@ -73,12 +73,6 @@ lgb.train <- function(params = list(),
}
}

# Setup temporary variables
params <- lgb.check.obj(params = params, obj = obj)
params <- lgb.check.eval(params = params, eval = eval)
fobj <- NULL
eval_functions <- list(NULL)

# set some parameters, resolving the way they were passed in with other parameters
# in `params`.
# this ensures that the model stored with Booster$save() correctly represents
Expand All @@ -93,14 +87,26 @@ lgb.train <- function(params = list(),
, params = params
, alternative_kwarg_value = nrounds
)
params <- lgb.check.wrapper_param(
main_param_name = "metric"
, params = params
, alternative_kwarg_value = NULL
)
params <- lgb.check.wrapper_param(
main_param_name = "objective"
, params = params
, alternative_kwarg_value = NULL
)
params <- lgb.check.wrapper_param(
main_param_name = "early_stopping_round"
, params = params
, alternative_kwarg_value = early_stopping_rounds
)
early_stopping_rounds <- params[["early_stopping_round"]]

# Check for objective (function or not)
# extract any function objects passed for objective or metric
params <- lgb.check.obj(params = params, obj = obj)
fobj <- NULL
if (is.function(params$objective)) {
fobj <- params$objective
params$objective <- "NONE"
Expand All @@ -110,6 +116,8 @@ lgb.train <- function(params = list(),
# (for backwards compatibility). If it is a list of functions, store
# all of them. This makes it possible to pass any mix of strings like "auc"
# and custom functions to eval
params <- lgb.check.eval(params = params, eval = eval)
eval_functions <- list(NULL)
if (is.function(eval)) {
eval_functions <- list(eval)
}
Expand Down
93 changes: 93 additions & 0 deletions R-package/tests/testthat/test_basic.R
Original file line number Diff line number Diff line change
Expand Up @@ -547,6 +547,52 @@ test_that("lgb.cv() respects showsd argument", {
expect_identical(evals_no_showsd[["eval_err"]], list())
})

test_that("lgb.cv() respects parameter aliases for objective", {
nrounds <- 3L
nfold <- 4L
dtrain <- lgb.Dataset(
data = train$data
, label = train$label
)
cv_bst <- lgb.cv(
data = dtrain
, params = list(
num_leaves = 5L
, application = "binary"
, num_iterations = nrounds
)
, nfold = nfold
)
expect_equal(cv_bst$best_iter, nrounds)
expect_named(cv_bst$record_evals[["valid"]], "binary_logloss")
expect_length(cv_bst$record_evals[["valid"]][["binary_logloss"]][["eval"]], nrounds)
expect_length(cv_bst$boosters, nfold)
})

test_that("lgb.cv() respects parameter aliases for metric", {
nrounds <- 3L
nfold <- 4L
dtrain <- lgb.Dataset(
data = train$data
, label = train$label
)
cv_bst <- lgb.cv(
data = dtrain
, params = list(
num_leaves = 5L
, objective = "binary"
, num_iterations = nrounds
, metric_types = c("auc", "binary_logloss")
)
, nfold = nfold
)
expect_equal(cv_bst$best_iter, nrounds)
expect_named(cv_bst$record_evals[["valid"]], c("auc", "binary_logloss"))
expect_length(cv_bst$record_evals[["valid"]][["binary_logloss"]][["eval"]], nrounds)
expect_length(cv_bst$record_evals[["valid"]][["auc"]][["eval"]], nrounds)
expect_length(cv_bst$boosters, nfold)
})

test_that("lgb.cv() respects eval_train_metric argument", {
dtrain <- lgb.Dataset(train$data, label = train$label)
params <- list(
Expand Down Expand Up @@ -616,6 +662,53 @@ test_that("lgb.train() works as expected with multiple eval metrics", {
)
})

test_that("lgb.train() respects parameter aliases for objective", {
nrounds <- 3L
dtrain <- lgb.Dataset(
data = train$data
, label = train$label
)
bst <- lgb.train(
data = dtrain
, params = list(
num_leaves = 5L
, application = "binary"
, num_iterations = nrounds
)
, valids = list(
"the_training_data" = dtrain
)
)
expect_named(bst$record_evals[["the_training_data"]], "binary_logloss")
expect_length(bst$record_evals[["the_training_data"]][["binary_logloss"]][["eval"]], nrounds)
expect_equal(bst$params[["objective"]], "binary")
})

test_that("lgb.train() respects parameter aliases for metric", {
nrounds <- 3L
dtrain <- lgb.Dataset(
data = train$data
, label = train$label
)
bst <- lgb.train(
data = dtrain
, params = list(
num_leaves = 5L
, objective = "binary"
, num_iterations = nrounds
, metric_types = c("auc", "binary_logloss")
)
, valids = list(
"train" = dtrain
)
)
record_results <- bst$record_evals[["train"]]
expect_equal(sort(names(record_results)), c("auc", "binary_logloss"))
expect_length(record_results[["auc"]][["eval"]], nrounds)
expect_length(record_results[["binary_logloss"]][["eval"]], nrounds)
expect_equal(bst$params[["metric"]], list("auc", "binary_logloss"))
})

test_that("lgb.train() rejects negative or 0 value passed to nrounds", {
dtrain <- lgb.Dataset(train$data, label = train$label)
params <- list(
Expand Down

0 comments on commit ac821f0

Please sign in to comment.