Skip to content

Commit

Permalink
[R-package] reduce cost of repeated parameter alias checks (#5141)
Browse files Browse the repository at this point in the history
* [R-package] reduce cost of repeated parameter alias checks

* formatting
  • Loading branch information
jameslamb authored Apr 13, 2022
1 parent b462d0a commit 204fa9c
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 0 deletions.
13 changes: 13 additions & 0 deletions R-package/R/aliases.R
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,18 @@
)])
}

# [description] Non-exported environment, used for caching details that only need to be
# computed once per R session.
.lgb_session_cache_env <- new.env()

# [description] List of respected parameter aliases. Wrapped in a function to take advantage of
# lazy evaluation (so it doesn't matter what order R sources files during installation).
# [return] A named list, where each key is a main LightGBM parameter and each value is a character
# vector of corresponding aliases.
.PARAMETER_ALIASES <- function() {
if (exists("PARAMETER_ALIASES", where = .lgb_session_cache_env)) {
return(get("PARAMETER_ALIASES", envir = .lgb_session_cache_env))
}
params_to_aliases <- jsonlite::fromJSON(
.Call(
LGBM_DumpParamAliases_R
Expand All @@ -47,6 +54,12 @@
aliases_with_main_name <- c(main_name, unlist(params_to_aliases[[main_name]]))
params_to_aliases[[main_name]] <- aliases_with_main_name
}
# store in cache so the next call to `.PARAMETER_ALIASES()` doesn't need to recompute this
assign(
x = "PARAMETER_ALIASES"
, value = params_to_aliases
, envir = .lgb_session_cache_env
)
return(params_to_aliases)
}

Expand Down
31 changes: 31 additions & 0 deletions R-package/tests/testthat/test_parameters.R
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,37 @@ test_that(".PARAMETER_ALIASES() returns a named list of character vectors, where
expect_equal(sort(param_aliases[["task"]]), c("task", "task_type"))
})

test_that(".PARAMETER_ALIASES() uses the internal session cache", {

cache_key <- "PARAMETER_ALIASES"

# clear cache, so this test isn't reliant on the order unit tests are run in
if (exists(cache_key, where = .lgb_session_cache_env)) {
rm(list = cache_key, envir = .lgb_session_cache_env)
}
expect_false(exists(cache_key, where = .lgb_session_cache_env))

# check that result looks correct for at least one parameter
iter_aliases <- .PARAMETER_ALIASES()[["num_iterations"]]
expect_true(is.character(iter_aliases))
expect_true(all(c("num_round", "nrounds") %in% iter_aliases))

# patch the cache to check that .PARAMETER_ALIASES() checks it
assign(
x = cache_key
, value = list(num_iterations = c("test", "other_test"))
, envir = .lgb_session_cache_env
)
iter_aliases <- .PARAMETER_ALIASES()[["num_iterations"]]
expect_equal(iter_aliases, c("test", "other_test"))

# re-set cache so this doesn't interfere with other unit tests
if (exists(cache_key, where = .lgb_session_cache_env)) {
rm(list = cache_key, envir = .lgb_session_cache_env)
}
expect_false(exists(cache_key, where = .lgb_session_cache_env))
})

test_that("training should warn if you use 'dart' boosting, specified with 'boosting' or aliases", {
for (boosting_param in .PARAMETER_ALIASES()[["boosting"]]) {
params <- list(
Expand Down

0 comments on commit 204fa9c

Please sign in to comment.