diff --git a/src/boosting/gbdt_model_text.cpp b/src/boosting/gbdt_model_text.cpp
index 73c3ea98d3f6..695eae3c70cb 100644
--- a/src/boosting/gbdt_model_text.cpp
+++ b/src/boosting/gbdt_model_text.cpp
@@ -632,14 +632,24 @@ std::vector<double> GBDT::FeatureImportance(int num_iteration, int importance_ty
   }
   std::vector<double> feature_importances(max_feature_idx_ + 1, 0.0);
+  bool warn_about_feature_number = false;
+  int max_feature_index_found = -1;
   if (importance_type == 0) {
     for (int iter = 0; iter < num_used_model; ++iter) {
       for (int split_idx = 0; split_idx < models_[iter]->num_leaves() - 1; ++split_idx) {
         if (models_[iter]->split_gain(split_idx) > 0) {
+          const int real_feature_index = models_[iter]->split_feature(split_idx);
 #ifdef DEBUG
-          CHECK_GE(models_[iter]->split_feature(split_idx), 0);
+          CHECK_GE(real_feature_index, 0);
 #endif
-          feature_importances[models_[iter]->split_feature(split_idx)] += 1.0;
+          if (static_cast<size_t>(real_feature_index) >= feature_importances.size()) {
+            warn_about_feature_number = true;
+            if (real_feature_index > max_feature_index_found) {
+              max_feature_index_found = real_feature_index;
+            }
+          } else {
+            feature_importances[real_feature_index] += 1.0;
+          }
         }
       }
     }
   } else if (importance_type == 1) {
@@ -647,16 +657,29 @@ std::vector<double> GBDT::FeatureImportance(int num_iteration, int importance_ty
     for (int iter = 0; iter < num_used_model; ++iter) {
       for (int split_idx = 0; split_idx < models_[iter]->num_leaves() - 1; ++split_idx) {
         if (models_[iter]->split_gain(split_idx) > 0) {
+          const int real_feature_index = models_[iter]->split_feature(split_idx);
 #ifdef DEBUG
-          CHECK_GE(models_[iter]->split_feature(split_idx), 0);
+          CHECK_GE(real_feature_index, 0);
 #endif
-          feature_importances[models_[iter]->split_feature(split_idx)] += models_[iter]->split_gain(split_idx);
+          if (static_cast<size_t>(real_feature_index) >= feature_importances.size()) {
+            warn_about_feature_number = true;
+            if (real_feature_index > max_feature_index_found) {
+              max_feature_index_found = real_feature_index;
+            }
+          } else {
+            feature_importances[real_feature_index] += models_[iter]->split_gain(split_idx);
+          }
         }
       }
     }
   } else {
     Log::Fatal("Unknown importance type: only support split=0 and gain=1");
   }
+  if (warn_about_feature_number) {
+    Log::Warning("Only %d features found in dataset for continual training, but at least %d features found in initial model.",
+                 static_cast<int>(feature_importances.size()), max_feature_index_found + 1);
+    Log::Warning("Please check the number of features used in continual training.");
+  }
   return feature_importances;
 }
 
diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py
index 4691120184aa..397c41cdfe87 100644
--- a/tests/python_package_test/test_engine.py
+++ b/tests/python_package_test/test_engine.py
@@ -1020,6 +1020,29 @@ def test_continue_train_multiclass():
     assert evals_result['valid_0']['multi_logloss'][-1] == pytest.approx(ret)
 
 
+def test_continue_train_different_feature_size(capsys):
+    np.random.seed(0)
+    train_X = np.hstack([np.ones(800).reshape(-1, 8), np.arange(200, 0, -1).reshape(-1, 2)])
+    train_y = np.sum(train_X[:, -2:], axis=1)
+    train_data = lgb.Dataset(train_X, label=train_y)
+    params = {
+        "objective": "regression",
+        "num_trees": 10,
+        "num_leaves": 31,
+        "verbose": -1,
+        'predict_disable_shape_check': True,
+    }
+    model = lgb.train(train_set=train_data, params=params)
+
+    train_X_cont = np.random.rand(100, 5)
+    train_y_cont = np.sum(train_X_cont, axis=1)
+    train_data_cont = lgb.Dataset(train_X_cont, label=train_y_cont)
+    params.update({"verbose": 2})
+    lgb.train(train_set=train_data_cont, params=params, init_model=model)
+    captured = capsys.readouterr()
+    assert captured.out.find("features found in dataset for continual training, but at least") != -1
+
+
 def test_cv():
     X_train, y_train = make_synthetic_regression()
     params = {'verbose': -1}