From ac821f0c963c8a901b9983b4a53ddcbeedd55ad3 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Tue, 4 Jan 2022 19:53:23 -0600 Subject: [PATCH] [R-package] respect aliases for objective and metric and lgb.train() and lgb.cv() (#4913) * [R-package] respect aliases for objective and metric * move eval code closer to eval processing * remove unnecessary diff * Update R-package/tests/testthat/test_basic.R --- R-package/R/lgb.cv.R | 22 +++++-- R-package/R/lgb.train.R | 22 +++++-- R-package/tests/testthat/test_basic.R | 93 +++++++++++++++++++++++++++ 3 files changed, 123 insertions(+), 14 deletions(-) diff --git a/R-package/R/lgb.cv.R b/R-package/R/lgb.cv.R index 025c2e56f10f..6ae0e01df306 100644 --- a/R-package/R/lgb.cv.R +++ b/R-package/R/lgb.cv.R @@ -105,12 +105,6 @@ lgb.cv <- function(params = list() data <- lgb.Dataset(data = data, label = label) } - # Setup temporary variables - params <- lgb.check.obj(params = params, obj = obj) - params <- lgb.check.eval(params = params, eval = eval) - fobj <- NULL - eval_functions <- list(NULL) - # set some parameters, resolving the way they were passed in with other parameters # in `params`. # this ensures that the model stored with Booster$save() correctly represents @@ -125,6 +119,16 @@ lgb.cv <- function(params = list() , params = params , alternative_kwarg_value = nrounds ) + params <- lgb.check.wrapper_param( + main_param_name = "metric" + , params = params + , alternative_kwarg_value = NULL + ) + params <- lgb.check.wrapper_param( + main_param_name = "objective" + , params = params + , alternative_kwarg_value = NULL + ) params <- lgb.check.wrapper_param( main_param_name = "early_stopping_round" , params = params @@ -132,7 +136,9 @@ lgb.cv <- function(params = list() ) early_stopping_rounds <- params[["early_stopping_round"]] - # Check for objective (function or not) + # extract any function objects passed for objective or metric + params <- lgb.check.obj(params = params, obj = obj) + fobj <- NULL if (is.function(params$objective)) { fobj <- params$objective params$objective <- "NONE" @@ -142,6 +148,8 @@ lgb.cv <- function(params = list() # (for backwards compatibility). If it is a list of functions, store # all of them. This makes it possible to pass any mix of strings like "auc" # and custom functions to eval + params <- lgb.check.eval(params = params, eval = eval) + eval_functions <- list(NULL) if (is.function(eval)) { eval_functions <- list(eval) } diff --git a/R-package/R/lgb.train.R b/R-package/R/lgb.train.R index 1fd40d596d06..fbfc7e5c6565 100644 --- a/R-package/R/lgb.train.R +++ b/R-package/R/lgb.train.R @@ -73,12 +73,6 @@ lgb.train <- function(params = list(), } } - # Setup temporary variables - params <- lgb.check.obj(params = params, obj = obj) - params <- lgb.check.eval(params = params, eval = eval) - fobj <- NULL - eval_functions <- list(NULL) - # set some parameters, resolving the way they were passed in with other parameters # in `params`. # this ensures that the model stored with Booster$save() correctly represents @@ -93,6 +87,16 @@ lgb.train <- function(params = list(), , params = params , alternative_kwarg_value = nrounds ) + params <- lgb.check.wrapper_param( + main_param_name = "metric" + , params = params + , alternative_kwarg_value = NULL + ) + params <- lgb.check.wrapper_param( + main_param_name = "objective" + , params = params + , alternative_kwarg_value = NULL + ) params <- lgb.check.wrapper_param( main_param_name = "early_stopping_round" , params = params @@ -100,7 +104,9 @@ lgb.train <- function(params = list(), ) early_stopping_rounds <- params[["early_stopping_round"]] - # Check for objective (function or not) + # extract any function objects passed for objective or metric + params <- lgb.check.obj(params = params, obj = obj) + fobj <- NULL if (is.function(params$objective)) { fobj <- params$objective params$objective <- "NONE" @@ -110,6 +116,8 @@ lgb.train <- function(params = list(), # (for backwards compatibility). If it is a list of functions, store # all of them. This makes it possible to pass any mix of strings like "auc" # and custom functions to eval + params <- lgb.check.eval(params = params, eval = eval) + eval_functions <- list(NULL) if (is.function(eval)) { eval_functions <- list(eval) } diff --git a/R-package/tests/testthat/test_basic.R b/R-package/tests/testthat/test_basic.R index 464412c6425c..0a46a364054a 100644 --- a/R-package/tests/testthat/test_basic.R +++ b/R-package/tests/testthat/test_basic.R @@ -547,6 +547,52 @@ test_that("lgb.cv() respects showsd argument", { expect_identical(evals_no_showsd[["eval_err"]], list()) }) +test_that("lgb.cv() respects parameter aliases for objective", { + nrounds <- 3L + nfold <- 4L + dtrain <- lgb.Dataset( + data = train$data + , label = train$label + ) + cv_bst <- lgb.cv( + data = dtrain + , params = list( + num_leaves = 5L + , application = "binary" + , num_iterations = nrounds + ) + , nfold = nfold + ) + expect_equal(cv_bst$best_iter, nrounds) + expect_named(cv_bst$record_evals[["valid"]], "binary_logloss") + expect_length(cv_bst$record_evals[["valid"]][["binary_logloss"]][["eval"]], nrounds) + expect_length(cv_bst$boosters, nfold) +}) + +test_that("lgb.cv() respects parameter aliases for metric", { + nrounds <- 3L + nfold <- 4L + dtrain <- lgb.Dataset( + data = train$data + , label = train$label + ) + cv_bst <- lgb.cv( + data = dtrain + , params = list( + num_leaves = 5L + , objective = "binary" + , num_iterations = nrounds + , metric_types = c("auc", "binary_logloss") + ) + , nfold = nfold + ) + expect_equal(cv_bst$best_iter, nrounds) + expect_named(cv_bst$record_evals[["valid"]], c("auc", "binary_logloss")) + expect_length(cv_bst$record_evals[["valid"]][["binary_logloss"]][["eval"]], nrounds) + expect_length(cv_bst$record_evals[["valid"]][["auc"]][["eval"]], nrounds) + expect_length(cv_bst$boosters, nfold) +}) + test_that("lgb.cv() respects eval_train_metric argument", { dtrain <- lgb.Dataset(train$data, label = train$label) params <- list( @@ -616,6 +662,53 @@ test_that("lgb.train() works as expected with multiple eval metrics", { ) }) +test_that("lgb.train() respects parameter aliases for objective", { + nrounds <- 3L + dtrain <- lgb.Dataset( + data = train$data + , label = train$label + ) + bst <- lgb.train( + data = dtrain + , params = list( + num_leaves = 5L + , application = "binary" + , num_iterations = nrounds + ) + , valids = list( + "the_training_data" = dtrain + ) + ) + expect_named(bst$record_evals[["the_training_data"]], "binary_logloss") + expect_length(bst$record_evals[["the_training_data"]][["binary_logloss"]][["eval"]], nrounds) + expect_equal(bst$params[["objective"]], "binary") +}) + +test_that("lgb.train() respects parameter aliases for metric", { + nrounds <- 3L + dtrain <- lgb.Dataset( + data = train$data + , label = train$label + ) + bst <- lgb.train( + data = dtrain + , params = list( + num_leaves = 5L + , objective = "binary" + , num_iterations = nrounds + , metric_types = c("auc", "binary_logloss") + ) + , valids = list( + "train" = dtrain + ) + ) + record_results <- bst$record_evals[["train"]] + expect_equal(sort(names(record_results)), c("auc", "binary_logloss")) + expect_length(record_results[["auc"]][["eval"]], nrounds) + expect_length(record_results[["binary_logloss"]][["eval"]], nrounds) + expect_equal(bst$params[["metric"]], list("auc", "binary_logloss")) +}) + test_that("lgb.train() rejects negative or 0 value passed to nrounds", { dtrain <- lgb.Dataset(train$data, label = train$label) params <- list(