Skip to content

Commit

Permalink
output warning when only set max_depth
Browse files Browse the repository at this point in the history
  • Loading branch information
guolinke committed Aug 16, 2017
1 parent f8d6db9 commit 50534f1
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 1 deletion.
3 changes: 2 additions & 1 deletion include/LightGBM/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ const std::string kDefaultTreeLearnerType = "serial";
const std::string kDefaultDevice = "cpu";
const std::string kDefaultBoostingType = "gbdt";
const std::string kDefaultObjectiveType = "regression";
const int kDefaultNumLeaves = 31;

/*!
* \brief The interface for Config
Expand Down Expand Up @@ -202,7 +203,7 @@ struct TreeConfig: public ConfigBase {
double lambda_l2 = 0.0f;
double min_gain_to_split = 0.0f;
// should > 1
int num_leaves = 31;
int num_leaves = kDefaultNumLeaves;
int feature_fraction_seed = 2;
double feature_fraction = 1.0f;
// max cache size(unit:MB) for historical histogram. < 0 means no limit
Expand Down
9 changes: 9 additions & 0 deletions src/io/config.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,14 @@ void OverallConfig::CheckParamConflict() {
boosting_config.tree_config.histogram_pool_size = -1;
}
}
// Check max_depth and num_leaves
if (boosting_config.tree_config.max_depth > 0) {
int full_num_leaves = std::pow(2, boosting_config.tree_config.max_depth);
if (full_num_leaves > boosting_config.tree_config.num_leaves
&& boosting_config.tree_config.num_leaves == kDefaultNumLeaves) {
Log::Warning("Accuarcy may be bad since you didn't set num_leaves.");
}
}
}

void IOConfig::Set(const std::unordered_map<std::string, std::string>& params) {
Expand Down Expand Up @@ -370,6 +378,7 @@ void TreeConfig::Set(const std::unordered_map<std::string, std::string>& params)
CHECK(feature_fraction > 0.0f && feature_fraction <= 1.0f);
GetDouble(params, "histogram_pool_size", &histogram_pool_size);
GetInt(params, "max_depth", &max_depth);
CHECK(max_depth > 0);
GetInt(params, "top_k", &top_k);
GetInt(params, "gpu_platform_id", &gpu_platform_id);
GetInt(params, "gpu_device_id", &gpu_device_id);
Expand Down

0 comments on commit 50534f1

Please sign in to comment.