Skip to content

Commit

Permalink
reproducible parameter alias resolution for wrappers (fixes #5304) (#…
Browse files Browse the repository at this point in the history
…5338)

* dump sorted parameter aliases

* update lgb.check.wrapper_param

* update _choose_param_value to look like lgb.check.wrapper_param

* apply suggestions from review

* reduce diff

* move DumpAliases to config

* remove unnecessary check

* restore parameter check
  • Loading branch information
jmoralez authored Jul 30, 2022
1 parent 212d145 commit 83627ff
Show file tree
Hide file tree
Showing 9 changed files with 219 additions and 170 deletions.
6 changes: 3 additions & 3 deletions R-package/R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -177,15 +177,15 @@ lgb.check.eval <- function(params, eval) {
# ways, the first item in this list is used:
#
# 1. the main (non-alias) parameter found in `params`
# 2. the first alias of that parameter found in `params`
# 2. the alias with the highest priority found in `params`
# 3. the keyword argument passed in
#
# For example, "num_iterations" can also be provided to lgb.train()
# via keyword "nrounds". lgb.train() will choose one value for this parameter
# based on the first match in this list:
#
# 1. params[["num_iterations]]
# 2. the first alias of "num_iterations" found in params
# 2. the highest priority alias of "num_iterations" found in params
# 3. the nrounds keyword argument
#
# If multiple aliases are found in `params` for the same parameter, they are
Expand All @@ -197,7 +197,7 @@ lgb.check.eval <- function(params, eval) {
lgb.check.wrapper_param <- function(main_param_name, params, alternative_kwarg_value) {

aliases <- .PARAMETER_ALIASES()[[main_param_name]]
aliases_provided <- names(params)[names(params) %in% aliases]
aliases_provided <- aliases[aliases %in% names(params)]
aliases_provided <- aliases_provided[aliases_provided != main_param_name]

# prefer the main parameter
Expand Down
1 change: 1 addition & 0 deletions R-package/tests/testthat/test_parameters.R
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,7 @@ test_that(".PARAMETER_ALIASES() returns a named list of character vectors, where
expect_true(all(sapply(param_aliases, is.character)))
expect_true(length(unique(names(param_aliases))) == length(param_aliases))
expect_equal(sort(param_aliases[["task"]]), c("task", "task_type"))
expect_equal(param_aliases[["bagging_fraction"]], c("bagging_fraction", "bagging", "sub_row", "subsample"))
})

test_that(".PARAMETER_ALIASES() uses the internal session cache", {
Expand Down
6 changes: 3 additions & 3 deletions R-package/tests/testthat/test_utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,7 @@ test_that("lgb.check.wrapper_param() prefers alias to keyword arg", {
expect_equal(params[["num_iterations"]], num_tree)
expect_identical(params, list(num_iterations = num_tree))

# switching the order should switch which one is chosen
# switching the order shouldn't switch which one is chosen
params2 <- lgb.check.wrapper_param(
main_param_name = "num_iterations"
, params = list(
Expand All @@ -132,6 +132,6 @@ test_that("lgb.check.wrapper_param() prefers alias to keyword arg", {
)
, alternative_kwarg_value = kwarg_val
)
expect_equal(params2[["num_iterations"]], n_estimators)
expect_identical(params2, list(num_iterations = n_estimators))
expect_equal(params2[["num_iterations"]], num_tree)
expect_identical(params2, list(num_iterations = num_tree))
})
26 changes: 13 additions & 13 deletions helpers/parameter_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -359,20 +359,20 @@ def gen_parameter_code(
str_to_write += " return str_buf.str();\n"
str_to_write += "}\n\n"

str_to_write += "const std::string Config::DumpAliases() {\n"
str_to_write += " std::stringstream str_buf;\n"
str_to_write += ' str_buf << "{";\n'
for idx, name in enumerate(names):
if idx > 0:
str_to_write += ', ";\n'
aliases = '\\", \\"'.join([alias for alias in names_with_aliases[name]])
aliases = f'[\\"{aliases}\\"]' if aliases else '[]'
str_to_write += f' str_buf << "\\"{name}\\": {aliases}'
str_to_write += '";\n'
str_to_write += ' str_buf << "}";\n'
str_to_write += " return str_buf.str();\n"
str_to_write += "}\n\n"
str_to_write += """const std::unordered_map<std::string, std::vector<std::string>>& Config::parameter2aliases() {
static std::unordered_map<std::string, std::vector<std::string>> map({"""
for name in names:
str_to_write += '\n {"' + name + '", '
if names_with_aliases[name]:
str_to_write += '{"' + '", "'.join(names_with_aliases[name]) + '"}},'
else:
str_to_write += '{}},'
str_to_write += """
});
return map;
}
"""
str_to_write += "} // namespace LightGBM\n"
with open(config_out_cpp, "w") as config_out_cpp_file:
config_out_cpp_file.write(str_to_write)
Expand Down
17 changes: 14 additions & 3 deletions include/LightGBM/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,14 @@ struct Config {
const std::unordered_map<std::string, std::string>& params,
const std::string& name, bool* out);

/*!
* \brief Sort aliases by length and then alphabetically
* \param x Alias 1
* \param y Alias 2
* \return true if x has higher priority than y
*/
inline static bool SortAlias(const std::string& x, const std::string& y);

static void KV2Map(std::unordered_map<std::string, std::string>* params, const char* kv);
static std::unordered_map<std::string, std::string> Str2Map(const char* parameters);

Expand Down Expand Up @@ -1063,6 +1071,7 @@ struct Config {
bool is_data_based_parallel = false;
LIGHTGBM_EXPORT void Set(const std::unordered_map<std::string, std::string>& params);
static const std::unordered_map<std::string, std::string>& alias_table();
static const std::unordered_map<std::string, std::vector<std::string>>& parameter2aliases();
static const std::unordered_set<std::string>& parameter_set();
std::vector<std::vector<double>> auc_mu_weights_matrix;
std::vector<std::vector<int>> interaction_constraints_vector;
Expand Down Expand Up @@ -1131,6 +1140,10 @@ inline bool Config::GetBool(
return false;
}

inline bool Config::SortAlias(const std::string& x, const std::string& y) {
return x.size() < y.size() || (x.size() == y.size() && x < y);
}

struct ParameterAlias {
static void KeyAliasTransform(std::unordered_map<std::string, std::string>* params) {
std::unordered_map<std::string, std::string> tmp_map;
Expand All @@ -1139,9 +1152,7 @@ struct ParameterAlias {
if (alias != Config::alias_table().end()) { // found alias
auto alias_set = tmp_map.find(alias->second);
if (alias_set != tmp_map.end()) { // alias already set
// set priority by length & alphabetically to ensure reproducible behavior
if (alias_set->second.size() < pair.first.size() ||
(alias_set->second.size() == pair.first.size() && alias_set->second < pair.first)) {
if (Config::SortAlias(alias_set->second, pair.first)) {
Log::Warning("%s is set with %s=%s, %s=%s will be ignored. Current value: %s=%s",
alias->second.c_str(), alias_set->second.c_str(), params->at(alias_set->second).c_str(),
pair.first.c_str(), pair.second.c_str(), alias->second.c_str(), params->at(alias_set->second).c_str());
Expand Down
17 changes: 12 additions & 5 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ class _ConfigAliases:
aliases = None

@staticmethod
def _get_all_param_aliases() -> Dict[str, Set[str]]:
def _get_all_param_aliases() -> Dict[str, List[str]]:
buffer_len = 1 << 20
tmp_out_len = ctypes.c_int64(0)
string_buffer = ctypes.create_string_buffer(buffer_len)
Expand All @@ -365,7 +365,7 @@ def _get_all_param_aliases() -> Dict[str, Set[str]]:
ptr_string_buffer))
aliases = json.loads(
string_buffer.value.decode('utf-8'),
object_hook=lambda obj: {k: set(v) | {k} for k, v in obj.items()}
object_hook=lambda obj: {k: [k] + v for k, v in obj.items()}
)
return aliases

Expand All @@ -375,9 +375,15 @@ def get(cls, *args) -> Set[str]:
cls.aliases = cls._get_all_param_aliases()
ret = set()
for i in args:
ret |= cls.aliases.get(i, {i})
ret.update(cls.get_sorted(i))
return ret

@classmethod
def get_sorted(cls, name: str) -> List[str]:
if cls.aliases is None:
cls.aliases = cls._get_all_param_aliases()
return cls.aliases.get(name, [name])

@classmethod
def get_by_alias(cls, *args) -> Set[str]:
if cls.aliases is None:
Expand All @@ -386,7 +392,7 @@ def get_by_alias(cls, *args) -> Set[str]:
for arg in args:
for aliases in cls.aliases.values():
if arg in aliases:
ret |= aliases
ret.update(aliases)
break
return ret

Expand All @@ -412,7 +418,8 @@ def _choose_param_value(main_param_name: str, params: Dict[str, Any], default_va
# avoid side effects on passed-in parameters
params = deepcopy(params)

aliases = _ConfigAliases.get(main_param_name) - {main_param_name}
aliases = _ConfigAliases.get_sorted(main_param_name)
aliases = [a for a in aliases if a != main_param_name]

# if main_param_name was provided, keep that value and remove all aliases
if main_param_name in params.keys():
Expand Down
25 changes: 25 additions & 0 deletions src/io/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -411,4 +411,29 @@ std::string Config::ToString() const {
return str_buf.str();
}

const std::string Config::DumpAliases() {
auto map = Config::parameter2aliases();
for (auto& pair : map) {
std::sort(pair.second.begin(), pair.second.end(), SortAlias);
}
std::stringstream str_buf;
str_buf << "{\n";
bool first = true;
for (const auto& pair : map) {
if (first) {
str_buf << " \"";
first = false;
} else {
str_buf << " , \"";
}
str_buf << pair.first << "\": [";
if (pair.second.size() > 0) {
str_buf << "\"" << CommonC::Join(pair.second, "\", \"") << "\"";
}
str_buf << "]\n";
}
str_buf << "}\n";
return str_buf.str();
}

} // namespace LightGBM
Loading

0 comments on commit 83627ff

Please sign in to comment.