diff --git a/tests/testthat/test-lightgbm.R b/tests/testthat/test-lightgbm.R index 6ea4bc9..c33c113 100644 --- a/tests/testthat/test-lightgbm.R +++ b/tests/testthat/test-lightgbm.R @@ -248,6 +248,40 @@ test_that("boost_tree with lightgbm",{ }) +test_that("bonsai correctly determines objective when label is a factor", { + skip_if_not_installed("lightgbm") + skip_if_not_installed("modeldata") + + suppressPackageStartupMessages({ + library(lightgbm) + library(dplyr) + }) + + data("penguins", package = "modeldata") + penguins <- penguins[complete.cases(penguins),] + + expect_error_free({ + bst <- train_lightgbm( + x = penguins[, c("bill_length_mm", "bill_depth_mm")], + y = penguins[["sex"]], + num_iterations = 5 + ) + }) + expect_equal(bst$params$objective, "binary") + expect_equal(bst$params$num_class, 1) + + expect_error_free({ + bst <- train_lightgbm( + x = penguins[, c("bill_length_mm", "bill_depth_mm")], + y = penguins[["species"]], + num_iterations = 5 + ) + }) + expect_equal(bst$params$objective, "multiclass") + expect_equal(bst$params$num_class, 3) +}) + + test_that("bonsai handles mtry vs mtry_prop gracefully", { skip_if_not_installed("modeldata")