diff --git a/include/LightGBM/c_api.h b/include/LightGBM/c_api.h index 22f8807ac3e8..9d4b703c25b9 100644 --- a/include/LightGBM/c_api.h +++ b/include/LightGBM/c_api.h @@ -678,6 +678,17 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterGetFeatureNames(BoosterHandle handle, size_t* out_buffer_len, char** out_strs); +/*! + * \brief Check that the feature names of the data match the ones used to train the booster. + * \param handle Handle of booster + * \param data_names Array with the feature names in the data + * \param data_num_features Number of features in the data + * \return 0 when succeed, -1 when failure happens + */ +LIGHTGBM_C_EXPORT int LGBM_BoosterValidateFeatureNames(BoosterHandle handle, + const char** data_names, + int data_num_features); + /*! * \brief Get number of features. * \param handle Handle of booster diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 0d494636a47a..7f0e7e9d9bc8 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -757,7 +757,7 @@ def __getstate__(self): return this def predict(self, data, start_iteration=0, num_iteration=-1, - raw_score=False, pred_leaf=False, pred_contrib=False, data_has_header=False): + raw_score=False, pred_leaf=False, pred_contrib=False, data_has_header=False, validate_features=False): """Predict logic. Parameters @@ -778,6 +778,9 @@ def predict(self, data, start_iteration=0, num_iteration=-1, data_has_header : bool, optional (default=False) Whether data has header. Used only for txt data. + validate_features : bool, optional (default=False) + If True, ensure that the features used to predict match the ones used to train. + Used only if data is pandas DataFrame. 
Returns ------- @@ -787,6 +790,17 @@ def predict(self, data, start_iteration=0, num_iteration=-1, """ if isinstance(data, Dataset): raise TypeError("Cannot use Dataset instance for prediction, please use raw data instead") + elif isinstance(data, pd_DataFrame) and validate_features: + data_names = [str(x) for x in data.columns] + ptr_names = (ctypes.c_char_p * len(data_names))() + ptr_names[:] = [x.encode('utf-8') for x in data_names] + _safe_call( + _LIB.LGBM_BoosterValidateFeatureNames( + self.handle, + ptr_names, + ctypes.c_int(len(data_names)), + ) + ) data = _data_from_pandas(data, None, None, self.pandas_categorical)[0] predict_type = C_API_PREDICT_NORMAL if raw_score: @@ -3501,7 +3515,8 @@ def dump_model(self, num_iteration=None, start_iteration=0, importance_type='spl def predict(self, data, start_iteration=0, num_iteration=None, raw_score=False, pred_leaf=False, pred_contrib=False, - data_has_header=False, **kwargs): + data_has_header=False, validate_features=False, + **kwargs): """Make a prediction. Parameters @@ -3535,6 +3550,9 @@ def predict(self, data, start_iteration=0, num_iteration=None, data_has_header : bool, optional (default=False) Whether the data has header. Used only if data is str. + validate_features : bool, optional (default=False) + If True, ensure that the features used to predict match the ones used to train. + Used only if data is pandas DataFrame. **kwargs Other parameters for the prediction. 
@@ -3552,7 +3570,7 @@ def predict(self, data, start_iteration=0, num_iteration=None, num_iteration = -1 return predictor.predict(data, start_iteration, num_iteration, raw_score, pred_leaf, pred_contrib, - data_has_header) + data_has_header, validate_features) def refit( self, diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py index ec22cd5a7476..682446455838 100644 --- a/python-package/lightgbm/sklearn.py +++ b/python-package/lightgbm/sklearn.py @@ -325,6 +325,9 @@ def __call__(self, preds: np.ndarray, dataset: Dataset) -> Tuple[str, float, boo Note that unlike the shap package, with ``pred_contrib`` we return a matrix with an extra column, where the last column is the expected value. + validate_features : bool, optional (default=False) + If True, ensure that the features used to predict match the ones used to train. + Used only if data is pandas DataFrame. **kwargs Other parameters for the prediction. @@ -820,7 +823,7 @@ def _get_meta_data(collection, name, i): ) + "\n\n" + _lgbmmodel_doc_custom_eval_note def predict(self, X, raw_score=False, start_iteration=0, num_iteration=None, - pred_leaf=False, pred_contrib=False, **kwargs): + pred_leaf=False, pred_contrib=False, validate_features=False, **kwargs): """Docstring is set after definition, using a template.""" if not self.__sklearn_is_fitted__(): raise LGBMNotFittedError("Estimator not fitted, call fit before exploiting the model.") @@ -853,7 +856,8 @@ def predict(self, X, raw_score=False, start_iteration=0, num_iteration=None, predict_params["num_threads"] = self._process_n_jobs(predict_params["num_threads"]) return self._Booster.predict(X, raw_score=raw_score, start_iteration=start_iteration, num_iteration=num_iteration, - pred_leaf=pred_leaf, pred_contrib=pred_contrib, **predict_params) + pred_leaf=pred_leaf, pred_contrib=pred_contrib, validate_features=validate_features, + **predict_params) predict.__doc__ = _lgbmmodel_doc_predict.format( description="Return the predicted 
value for each sample.", @@ -1087,10 +1091,12 @@ def fit( + _base_doc[_base_doc.find('eval_metric :'):]) def predict(self, X, raw_score=False, start_iteration=0, num_iteration=None, - pred_leaf=False, pred_contrib=False, **kwargs): + pred_leaf=False, pred_contrib=False, validate_features=False, + **kwargs): """Docstring is inherited from the LGBMModel.""" result = self.predict_proba(X, raw_score, start_iteration, num_iteration, - pred_leaf, pred_contrib, **kwargs) + pred_leaf, pred_contrib, validate_features, + **kwargs) if callable(self._objective) or raw_score or pred_leaf or pred_contrib: return result else: @@ -1100,9 +1106,9 @@ def predict(self, X, raw_score=False, start_iteration=0, num_iteration=None, predict.__doc__ = LGBMModel.predict.__doc__ def predict_proba(self, X, raw_score=False, start_iteration=0, num_iteration=None, - pred_leaf=False, pred_contrib=False, **kwargs): + pred_leaf=False, pred_contrib=False, validate_features=False, **kwargs): """Docstring is set after definition, using a template.""" - result = super().predict(X, raw_score, start_iteration, num_iteration, pred_leaf, pred_contrib, **kwargs) + result = super().predict(X, raw_score, start_iteration, num_iteration, pred_leaf, pred_contrib, validate_features, **kwargs) if callable(self._objective) and not (raw_score or pred_leaf or pred_contrib): _log_warning("Cannot compute class probabilities or labels " "due to the usage of customized objective function.\n" diff --git a/src/c_api.cpp b/src/c_api.cpp index 75fb2e4bd249..ebac89fe71e0 100644 --- a/src/c_api.cpp +++ b/src/c_api.cpp @@ -2129,6 +2129,27 @@ int LGBM_BoosterPredictForCSC(BoosterHandle handle, API_END(); } +int LGBM_BoosterValidateFeatureNames(BoosterHandle handle, + const char** data_names, + int data_num_features) { + API_BEGIN(); + int booster_num_features; + size_t out_buffer_len; + LGBM_BoosterGetFeatureNames(handle, 0, &booster_num_features, 0, &out_buffer_len, nullptr); + if (booster_num_features != data_num_features) { + 
Log::Fatal("Model was trained on %d features, but got %d input features to predict.", booster_num_features, data_num_features); + } + std::vector<std::vector<char>> tmp_names(booster_num_features, std::vector<char>(out_buffer_len)); + std::vector<char*> booster_names = Vector2Ptr(&tmp_names); + LGBM_BoosterGetFeatureNames(handle, data_num_features, &booster_num_features, out_buffer_len, &out_buffer_len, booster_names.data()); + for (int i = 0; i < booster_num_features; ++i) { + if (strcmp(data_names[i], booster_names[i]) != 0) { + Log::Fatal("Expected '%s' at position %d but found '%s'", booster_names[i], i, data_names[i]); + } + } + API_END(); +} + int LGBM_BoosterPredictForMat(BoosterHandle handle, const void* data, int data_type, diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py index 0aa27349b677..e53bb6b0e594 100644 --- a/tests/python_package_test/test_engine.py +++ b/tests/python_package_test/test_engine.py @@ -18,6 +18,7 @@ from sklearn.model_selection import GroupKFold, TimeSeriesSplit, train_test_split import lightgbm as lgb +from lightgbm.compat import PANDAS_INSTALLED, pd_DataFrame from .utils import (dummy_obj, load_boston, load_breast_cancer, load_digits, load_iris, logistic_sigmoid, make_synthetic_regression, mse_obj, sklearn_multiclass_custom_objective, softmax) @@ -3623,3 +3624,21 @@ def test_cegb_split_buffer_clean(): predicts = model.predict(test_data) rmse = np.sqrt(mean_squared_error(test_y, predicts)) assert rmse < 10.0 + + +@pytest.mark.skipif(not PANDAS_INSTALLED, reason='pandas is not installed') +def test_validate_features(): + X, y = make_synthetic_regression() + features = ['x1', 'x2', 'x3', 'x4'] + df = pd_DataFrame(X, columns=features) + ds = lgb.Dataset(df, y) + bst = lgb.train({'num_leaves': 15, 'verbose': -1}, ds, num_boost_round=10) + assert bst.feature_name() == features + + # try to predict with a different feature + df2 = df.rename(columns={'x3': 'z'}) + with pytest.raises(lgb.basic.LightGBMError, match="Expected 
'x3' at position 2 but found 'z'"): + bst.predict(df2, validate_features=True) + + # check that disabling the check doesn't raise the error + bst.predict(df2, validate_features=False) diff --git a/tests/python_package_test/test_sklearn.py b/tests/python_package_test/test_sklearn.py index fd513537a486..2fdd31c23be1 100644 --- a/tests/python_package_test/test_sklearn.py +++ b/tests/python_package_test/test_sklearn.py @@ -18,11 +18,30 @@ from sklearn.utils.validation import check_is_fitted import lightgbm as lgb +from lightgbm.compat import PANDAS_INSTALLED, pd_DataFrame from .utils import (load_boston, load_breast_cancer, load_digits, load_iris, load_linnerud, make_ranking, make_synthetic_regression, sklearn_multiclass_custom_objective, softmax) decreasing_generator = itertools.count(0, -1) +task_to_model_factory = { + 'ranking': lgb.LGBMRanker, + 'classification': lgb.LGBMClassifier, + 'regression': lgb.LGBMRegressor, +} + + +def _create_data(task): + if task == 'ranking': + X, y, g = make_ranking(n_features=4) + g = np.bincount(g) + elif task == 'classification': + X, y = load_iris(return_X_y=True) + g = None + elif task == 'regression': + X, y = make_synthetic_regression() + g = None + return X, y, g class UnpicklableCallback: @@ -1244,16 +1263,7 @@ def test_sklearn_integration(estimator, check): @pytest.mark.parametrize('task', ['classification', 'ranking', 'regression']) def test_training_succeeds_when_data_is_dataframe_and_label_is_column_array(task): pd = pytest.importorskip("pandas") - if task == 'ranking': - X, y, g = make_ranking() - g = np.bincount(g) - model_factory = lgb.LGBMRanker - elif task == 'classification': - X, y = load_iris(return_X_y=True) - model_factory = lgb.LGBMClassifier - elif task == 'regression': - X, y = make_synthetic_regression() - model_factory = lgb.LGBMRegressor + X, y, g = _create_data(task) X = pd.DataFrame(X) y_col_array = y.reshape(-1, 1) params = { @@ -1261,6 +1271,7 @@ def 
test_training_succeeds_when_data_is_dataframe_and_label_is_column_array(task 'num_leaves': 3, 'random_state': 0 } + model_factory = task_to_model_factory[task] with pytest.warns(UserWarning, match='column-vector'): if task == 'ranking': model_1d = model_factory(**params).fit(X, y, group=g) @@ -1315,3 +1326,25 @@ def test_default_n_jobs(tmp_path): with open(tmp_path / "model.txt", "r") as f: model_txt = f.read() assert bool(re.search(rf"\[num_threads: {n_cores}\]", model_txt)) + + +@pytest.mark.skipif(not PANDAS_INSTALLED, reason='pandas is not installed') +@pytest.mark.parametrize('task', ['classification', 'ranking', 'regression']) +def test_validate_features(task): + X, y, g = _create_data(task) + features = ['x1', 'x2', 'x3', 'x4'] + df = pd_DataFrame(X, columns=features) + model = task_to_model_factory[task](n_estimators=10, num_leaves=15, verbose=-1) + if task == 'ranking': + model.fit(df, y, group=g) + else: + model.fit(df, y) + assert model.feature_name_ == features + + # try to predict with a different feature + df2 = df.rename(columns={'x2': 'z'}) + with pytest.raises(lgb.basic.LightGBMError, match="Expected 'x2' at position 1 but found 'z'"): + model.predict(df2, validate_features=True) + + # check that disabling the check doesn't raise the error + model.predict(df2, validate_features=False)