From 84f5e1b8983712344c718a510a0f3997eff9dfd8 Mon Sep 17 00:00:00 2001
From: Jose Morales
Date: Thu, 23 Dec 2021 19:20:34 -0600
Subject: [PATCH 01/19] check feature names and order in predict with
 dataframe

---
 python-package/lightgbm/basic.py        | 13 +++++++++----
 tests/python_package_test/test_basic.py | 26 +++++++++++++++++++++++++
 2 files changed, 35 insertions(+), 4 deletions(-)

diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py
index 64f1cb31edaa..75e786bac063 100644
--- a/python-package/lightgbm/basic.py
+++ b/python-package/lightgbm/basic.py
@@ -512,12 +512,17 @@ def _get_bad_pandas_dtypes(dtypes):
     return bad_indices
 
 
-def _data_from_pandas(data, feature_name, categorical_feature, pandas_categorical):
+def _data_from_pandas(data, feature_name, categorical_feature, pandas_categorical, is_predict=False):
     if isinstance(data, pd_DataFrame):
         if len(data.shape) != 2 or data.shape[0] < 1:
             raise ValueError('Input data must be 2 dimensional and non empty.')
         if feature_name == 'auto' or feature_name is None:
             data = data.rename(columns=str)
+        elif isinstance(feature_name, list) and is_predict:
+            missing_features = set(feature_name) - set(data.columns.astype(str))
+            if missing_features:
+                raise ValueError(f'The following features are missing: {missing_features}')
+            data = data[feature_name]  # ensure column order
         cat_cols = [col for col, dtype in zip(data.columns, data.dtypes) if isinstance(dtype, pd_CategoricalDtype)]
         cat_cols_not_ordered = [col for col in cat_cols if not data[col].cat.ordered]
         if pandas_categorical is None:  # train dataset
@@ -767,9 +772,6 @@ def predict(self, data, start_iteration=0, num_iteration=-1,
             Prediction result.
             Can be sparse or a list of sparse objects (each element represents predictions for one class) for feature contributions (when ``pred_contrib=True``).
         """
-        if isinstance(data, Dataset):
-            raise TypeError("Cannot use Dataset instance for prediction, please use raw data instead")
-        data = _data_from_pandas(data, None, None, self.pandas_categorical)[0]
         predict_type = C_API_PREDICT_NORMAL
         if raw_score:
             predict_type = C_API_PREDICT_RAW_SCORE
@@ -3493,6 +3495,9 @@ def predict(self, data, start_iteration=0, num_iteration=None,
             Prediction result.
             Can be sparse or a list of sparse objects (each element represents predictions for one class) for feature contributions (when ``pred_contrib=True``).
         """
+        if isinstance(data, Dataset):
+            raise TypeError("Cannot use Dataset instance for prediction, please use raw data instead")
+        data = _data_from_pandas(data, self.feature_name(), None, self.pandas_categorical, is_predict=True)[0]
         predictor = self._to_predictor(deepcopy(kwargs))
         if num_iteration is None:
             if start_iteration <= 0:
diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py
index 18a8403eba85..66bf7785d817 100644
--- a/tests/python_package_test/test_basic.py
+++ b/tests/python_package_test/test_basic.py
@@ -579,3 +579,29 @@ def test_param_aliases():
     assert all(len(i) >= 1 for i in aliases.values())
     assert all(k in v for k, v in aliases.items())
     assert lgb.basic._ConfigAliases.get('config', 'task') == {'config', 'config_file', 'task', 'task_type'}
+
+
+@pytest.mark.skipif(not PANDAS_INSTALLED, reason='pandas is not installed')
+def test_predict_with_dataframe_checks_features():
+    df = pd_DataFrame(
+        {
+            'x1': np.random.rand(100),
+            'x2': np.random.rand(100),
+            'x3': np.random.rand(100),
+            'y': np.random.rand(100),
+        }
+    )
+    features = ['x1', 'x2', 'x3']
+    ds = lgb.Dataset(df[features], df['y'])
+    bst = lgb.train({'num_leaves': 15, 'verbose': -1}, ds, num_boost_round=10)
+    assert bst.feature_name() == features
+
+    # try to predict with a different feature
+    df2 = df.rename(columns={'x1': 'z'})
+    with pytest.raises(ValueError, match="The following features are missing: {'x1'}"):
+        bst.predict(df2)
+
+    # predict with the features out of order
+    preds_sorted_features = bst.predict(df[features])
+    preds_reversed_features = bst.predict(df[features[::-1]])
+    np.testing.assert_equal(preds_sorted_features, preds_reversed_features)

From 0d1347502f2c6f55e5c853589158422279e881c3 Mon Sep 17 00:00:00 2001
From: Jose Morales
Date: Thu, 23 Dec 2021 19:35:58 -0600
Subject: [PATCH 02/19] slice df in predict to remove the target

---
 tests/python_package_test/test_basic.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py
index 66bf7785d817..f9a706115aa0 100644
--- a/tests/python_package_test/test_basic.py
+++ b/tests/python_package_test/test_basic.py
@@ -599,7 +599,7 @@ def test_predict_with_dataframe_checks_features():
     # try to predict with a different feature
     df2 = df.rename(columns={'x1': 'z'})
     with pytest.raises(ValueError, match="The following features are missing: {'x1'}"):
-        bst.predict(df2)
+        bst.predict(df2[['z', 'x2', 'x3']])
 
     # predict with the features out of order
     preds_sorted_features = bst.predict(df[features])

From 1adc7f2cbbb4a8610635dcceae5d07dc9efdf095 Mon Sep 17 00:00:00 2001
From: Jose Morales
Date: Thu, 23 Dec 2021 19:40:31 -0600
Subject: [PATCH 03/19] scramble features

---
 tests/python_package_test/test_basic.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py
index f9a706115aa0..c8d6107e2ce8 100644
--- a/tests/python_package_test/test_basic.py
+++ b/tests/python_package_test/test_basic.py
@@ -603,5 +603,5 @@ def test_predict_with_dataframe_checks_features():
 
     # predict with the features out of order
     preds_sorted_features = bst.predict(df[features])
-    preds_reversed_features = bst.predict(df[features[::-1]])
-    np.testing.assert_equal(preds_sorted_features, preds_reversed_features)
+    preds_out_of_order_features = bst.predict(df[['x3', 'x1', 'x2']])
+    np.testing.assert_equal(preds_sorted_features, preds_out_of_order_features)

From a3cfcad143766ad5ab5879c6cdb80d6dfef0e855 Mon Sep 17 00:00:00 2001
From: Jose Morales
Date: Thu, 23 Dec 2021 19:59:13 -0600
Subject: [PATCH 04/19] handle int column names

---
 python-package/lightgbm/basic.py | 6 ++++--
 1 file changed, 4 insertions(+), 2 deletions(-)

diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py
index 75e786bac063..b3774e4fd3c1 100644
--- a/python-package/lightgbm/basic.py
+++ b/python-package/lightgbm/basic.py
@@ -519,10 +519,12 @@ def _data_from_pandas(data, feature_name, categorical_feature, pandas_categorica
         if feature_name == 'auto' or feature_name is None:
             data = data.rename(columns=str)
         elif isinstance(feature_name, list) and is_predict:
-            missing_features = set(feature_name) - set(data.columns.astype(str))
+            features_df = data.columns.astype(str).tolist()
+            missing_features = set(feature_name) - set(features_df)
             if missing_features:
                 raise ValueError(f'The following features are missing: {missing_features}')
-            data = data[feature_name]  # ensure column order
+            sort_idxs = [features_df.index(feature) for feature in feature_name]
+            data = data.iloc[:, sort_idxs]  # ensure column order
         cat_cols = [col for col, dtype in zip(data.columns, data.dtypes) if isinstance(dtype, pd_CategoricalDtype)]
         cat_cols_not_ordered = [col for col in cat_cols if not data[col].cat.ordered]
         if pandas_categorical is None:  # train dataset

From e0827d007a0e1ba72d0ae40f823b6fec973e7aae Mon Sep 17 00:00:00 2001
From: Jose Morales
Date: Fri, 24 Dec 2021 10:26:27 -0600
Subject: [PATCH 05/19] only change column order when needed

---
 python-package/lightgbm/basic.py | 9 +++++----
 1 file changed, 5 insertions(+), 4 deletions(-)

diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py
index b3774e4fd3c1..400c0a7e47e2 100644
--- a/python-package/lightgbm/basic.py
+++ b/python-package/lightgbm/basic.py
@@ -519,12 +519,13 @@ def _data_from_pandas(data, feature_name, categorical_feature, pandas_categorica
         if feature_name == 'auto' or feature_name is None:
             data = data.rename(columns=str)
         elif isinstance(feature_name, list) and is_predict:
-            features_df = data.columns.astype(str).tolist()
-            missing_features = set(feature_name) - set(features_df)
+            df_features = [str(x) for x in data.columns]
+            missing_features = set(feature_name) - set(df_features)
             if missing_features:
                 raise ValueError(f'The following features are missing: {missing_features}')
-            sort_idxs = [features_df.index(feature) for feature in feature_name]
-            data = data.iloc[:, sort_idxs]  # ensure column order
+            sort_idxs = [df_features.index(feature) for feature in feature_name]
+            if not all(x == i for i, x in enumerate(sort_idxs)):
+                data = data.iloc[:, sort_idxs]  # ensure column order
         cat_cols = [col for col, dtype in zip(data.columns, data.dtypes) if isinstance(dtype, pd_CategoricalDtype)]
         cat_cols_not_ordered = [col for col in cat_cols if not data[col].cat.ordered]
         if pandas_categorical is None:  # train dataset

From 8d814f50242862def3f8205ab7f5aada7209c571 Mon Sep 17 00:00:00 2001
From: Jose Morales
Date: Wed, 29 Dec 2021 22:12:22 -0600
Subject: [PATCH 06/19] include validate_features param in booster and sklearn
 estimators

---
 examples/python-guide/advanced_example.py |  5 +-
 python-package/lightgbm/basic.py          | 14 +++--
 python-package/lightgbm/sklearn.py        | 15 +++---
 tests/python_package_test/test_basic.py   | 30 ++++++-----
 tests/python_package_test/test_sklearn.py | 63 +++++++++++++++++++---
 5 files changed, 90 insertions(+), 37 deletions(-)

diff --git a/examples/python-guide/advanced_example.py b/examples/python-guide/advanced_example.py
index 54b62cdb1563..bd30e2c93f57 100644
--- a/examples/python-guide/advanced_example.py
+++ b/examples/python-guide/advanced_example.py
@@ -81,7 +81,8 @@
 # load model to predict
 bst = lgb.Booster(model_file='model.txt')
 # can only predict with the best iteration (or the saving iteration)
-y_pred = bst.predict(X_test)
+# disable validating feature names since we know they're correct
+y_pred = bst.predict(X_test, validate_features=False)
 # eval with loaded model
 rmse_loaded_model = mean_squared_error(y_test, y_pred) ** 0.5
 print(f"The RMSE of loaded model's prediction is: {rmse_loaded_model}")
@@ -94,7 +95,7 @@
 with open('model.pkl', 'rb') as fin:
     pkl_bst = pickle.load(fin)
 # can predict with any iteration when loaded in pickle way
-y_pred = pkl_bst.predict(X_test, num_iteration=7)
+y_pred = pkl_bst.predict(X_test, num_iteration=7, validate_features=False)
 # eval with loaded model
 rmse_pickled_model = mean_squared_error(y_test, y_pred) ** 0.5
 print(f"The RMSE of pickled model's prediction is: {rmse_pickled_model}")
diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py
index 400c0a7e47e2..a39aac87b330 100644
--- a/python-package/lightgbm/basic.py
+++ b/python-package/lightgbm/basic.py
@@ -512,17 +512,20 @@ def _get_bad_pandas_dtypes(dtypes):
     return bad_indices
 
 
-def _data_from_pandas(data, feature_name, categorical_feature, pandas_categorical, is_predict=False):
+def _data_from_pandas(data, feature_name, categorical_feature, pandas_categorical, validate_features=False):
     if isinstance(data, pd_DataFrame):
         if len(data.shape) != 2 or data.shape[0] < 1:
             raise ValueError('Input data must be 2 dimensional and non empty.')
         if feature_name == 'auto' or feature_name is None:
             data = data.rename(columns=str)
-        elif isinstance(feature_name, list) and is_predict:
+        elif isinstance(feature_name, list) and validate_features:
             df_features = [str(x) for x in data.columns]
             missing_features = set(feature_name) - set(df_features)
             if missing_features:
-                raise ValueError(f'The following features are missing: {missing_features}')
+                raise ValueError(
+                    f"The following features are missing: {missing_features}.\n"
+                    "If you're sure the features are correct you can disable this check by setting validate_features=False"
+                )
             sort_idxs = [df_features.index(feature) for feature in feature_name]
             if not all(x == i for i, x in enumerate(sort_idxs)):
                 data = data.iloc[:, sort_idxs]  # ensure column order
@@ -3453,7 +3456,8 @@ def dump_model(self, num_iteration=None, start_iteration=0, importance_type='spl
 
     def predict(self, data, start_iteration=0, num_iteration=None,
                 raw_score=False, pred_leaf=False, pred_contrib=False,
-                data_has_header=False, is_reshape=True, **kwargs):
+                data_has_header=False, is_reshape=True, validate_features=True,
+                **kwargs):
         """Make a prediction.
 
         Parameters
@@ -3500,7 +3504,7 @@ def predict(self, data, start_iteration=0, num_iteration=None,
         """
         if isinstance(data, Dataset):
             raise TypeError("Cannot use Dataset instance for prediction, please use raw data instead")
-        data = _data_from_pandas(data, self.feature_name(), None, self.pandas_categorical, is_predict=True)[0]
+        data = _data_from_pandas(data, self.feature_name(), None, self.pandas_categorical, validate_features=validate_features)[0]
         predictor = self._to_predictor(deepcopy(kwargs))
         if num_iteration is None:
             if start_iteration <= 0:
diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py
index fa1769897736..39ccb800d98c 100644
--- a/python-package/lightgbm/sklearn.py
+++ b/python-package/lightgbm/sklearn.py
@@ -802,7 +802,7 @@ def _get_meta_data(collection, name, i):
     ) + "\n\n" + _lgbmmodel_doc_custom_eval_note
 
     def predict(self, X, raw_score=False, start_iteration=0, num_iteration=None,
-                pred_leaf=False, pred_contrib=False, **kwargs):
+                pred_leaf=False, pred_contrib=False, validate_features=True, **kwargs):
         """Docstring is set after definition, using a template."""
         if not self.__sklearn_is_fitted__():
             raise LGBMNotFittedError("Estimator not fitted, call fit before exploiting the model.")
@@ -829,7 +829,8 @@ def predict(self, X, raw_score=False, start_iteration=0, num_iteration=None,
             predict_params.pop(alias, None)
         predict_params.update(kwargs)
         return self._Booster.predict(X, raw_score=raw_score, start_iteration=start_iteration, num_iteration=num_iteration,
-                                     pred_leaf=pred_leaf, pred_contrib=pred_contrib, **predict_params)
+                                     pred_leaf=pred_leaf, pred_contrib=pred_contrib, validate_features=validate_features,
+                                     **predict_params)
 
     predict.__doc__ = _lgbmmodel_doc_predict.format(
         description="Return the predicted value for each sample.",
@@ -1063,10 +1064,12 @@ def fit(
                      + _base_doc[_base_doc.find('eval_metric :'):])
 
     def predict(self, X, raw_score=False, start_iteration=0, num_iteration=None,
-                pred_leaf=False, pred_contrib=False, **kwargs):
+                pred_leaf=False, pred_contrib=False, validate_features=True,
+                **kwargs):
         """Docstring is inherited from the LGBMModel."""
         result = self.predict_proba(X, raw_score, start_iteration, num_iteration,
-                                    pred_leaf, pred_contrib, **kwargs)
+                                    pred_leaf, pred_contrib, validate_features,
+                                    **kwargs)
         if callable(self._objective) or raw_score or pred_leaf or pred_contrib:
             return result
         else:
@@ -1076,9 +1079,9 @@ def predict(self, X, raw_score=False, start_iteration=0, num_iteration=None,
     predict.__doc__ = LGBMModel.predict.__doc__
 
     def predict_proba(self, X, raw_score=False, start_iteration=0, num_iteration=None,
-                      pred_leaf=False, pred_contrib=False, **kwargs):
+                      pred_leaf=False, pred_contrib=False, validate_features=True, **kwargs):
         """Docstring is set after definition, using a template."""
-        result = super().predict(X, raw_score, start_iteration, num_iteration, pred_leaf, pred_contrib, **kwargs)
+        result = super().predict(X, raw_score, start_iteration, num_iteration, pred_leaf, pred_contrib, validate_features, **kwargs)
         if callable(self._objective) and not (raw_score or pred_leaf or pred_contrib):
             _log_warning("Cannot compute class probabilities or labels "
                          "due to the usage of customized objective function.\n"
diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py
index c8d6107e2ce8..8218fc794f2b 100644
--- a/tests/python_package_test/test_basic.py
+++ b/tests/python_package_test/test_basic.py
@@ -12,7 +12,7 @@
 import lightgbm as lgb
 from lightgbm.compat import PANDAS_INSTALLED, pd_DataFrame, pd_Series
 
-from .utils import load_breast_cancer
+from .utils import load_breast_cancer, make_synthetic_regression
 
 
 def test_basic(tmp_path):
@@ -583,25 +583,27 @@ def test_param_aliases():
 
 @pytest.mark.skipif(not PANDAS_INSTALLED, reason='pandas is not installed')
 def test_predict_with_dataframe_checks_features():
-    df = pd_DataFrame(
-        {
-            'x1': np.random.rand(100),
-            'x2': np.random.rand(100),
-            'x3': np.random.rand(100),
-            'y': np.random.rand(100),
-        }
-    )
-    features = ['x1', 'x2', 'x3']
-    ds = lgb.Dataset(df[features], df['y'])
+    X, y = make_synthetic_regression()
+    features = ['x1', 'x2', 'x3', 'x4']
+    df = pd_DataFrame(X, columns=features)
+    ds = lgb.Dataset(df, y)
     bst = lgb.train({'num_leaves': 15, 'verbose': -1}, ds, num_boost_round=10)
     assert bst.feature_name() == features
 
     # try to predict with a different feature
     df2 = df.rename(columns={'x1': 'z'})
     with pytest.raises(ValueError, match="The following features are missing: {'x1'}"):
-        bst.predict(df2[['z', 'x2', 'x3']])
+        bst.predict(df2)
+
+    # check that disabling the check doesn't raise the error
+    bst.predict(df2, validate_features=False)
 
     # predict with the features out of order
     preds_sorted_features = bst.predict(df[features])
-    preds_out_of_order_features = bst.predict(df[['x3', 'x1', 'x2']])
-    np.testing.assert_equal(preds_sorted_features, preds_out_of_order_features)
+    scrambled_features = ['x3', 'x1', 'x4', 'x2']
+    preds_scrambled_features = bst.predict(df[scrambled_features])
+    np.testing.assert_equal(preds_sorted_features, preds_scrambled_features)
+
+    # check that disabling the check doesn't raise an error and produces incorrect predictions
+    preds_scrambled_features_no_check = bst.predict(df[scrambled_features], validate_features=False)
+    assert any(preds_sorted_features != preds_scrambled_features_no_check)
diff --git a/tests/python_package_test/test_sklearn.py b/tests/python_package_test/test_sklearn.py
index d4112078f39e..c77cd024aac9 100644
--- a/tests/python_package_test/test_sklearn.py
+++ b/tests/python_package_test/test_sklearn.py
@@ -17,6 +17,7 @@
 from sklearn.utils.validation import check_is_fitted
 
 import lightgbm as lgb
+from lightgbm.compat import PANDAS_INSTALLED, pd_DataFrame
 
 from .utils import (load_boston, load_breast_cancer, load_digits, load_iris, load_linnerud, make_ranking,
                     make_synthetic_regression)
@@ -31,6 +32,24 @@
 from sklearn.utils.estimator_checks import parametrize_with_checks
 
 decreasing_generator = itertools.count(0, -1)
+task_to_model_factory = {
+    'ranking': lgb.LGBMRanker,
+    'classification': lgb.LGBMClassifier,
+    'regression': lgb.LGBMRegressor,
+}
+
+
+def _create_data(task):
+    if task == 'ranking':
+        X, y, g = make_ranking(n_features=4)
+        g = np.bincount(g)
+    elif task == 'classification':
+        X, y = load_iris(return_X_y=True)
+        g = None
+    elif task == 'regression':
+        X, y = make_synthetic_regression()
+        g = None
+    return X, y, g
 
 
 class UnpicklableCallback:
@@ -1325,16 +1344,7 @@ def test_parameters_default_constructible(estimator):
 @pytest.mark.parametrize('task', ['classification', 'ranking', 'regression'])
 def test_training_succeeds_when_data_is_dataframe_and_label_is_column_array(task):
     pd = pytest.importorskip("pandas")
-    if task == 'ranking':
-        X, y, g = make_ranking()
-        g = np.bincount(g)
-        model_factory = lgb.LGBMRanker
-    elif task == 'classification':
-        X, y = load_iris(return_X_y=True)
-        model_factory = lgb.LGBMClassifier
-    elif task == 'regression':
-        X, y = make_synthetic_regression()
-        model_factory = lgb.LGBMRegressor
+    X, y, g = _create_data(task)
     X = pd.DataFrame(X)
    y_col_array = y.reshape(-1, 1)
    params = {
@@ -1342,6 +1352,7 @@ def test_training_succeeds_when_data_is_dataframe_and_label_is_column_array(task
         'num_leaves': 3,
         'random_state': 0
     }
+    model_factory = task_to_model_factory[task]
     with pytest.warns(UserWarning, match='column-vector'):
         if task == 'ranking':
             model_1d = model_factory(**params).fit(X, y, group=g)
@@ -1353,3 +1364,35 @@ def test_training_succeeds_when_data_is_dataframe_and_label_is_column_array(task
     preds_1d = model_1d.predict(X)
     preds_2d = model_2d.predict(X)
     np.testing.assert_array_equal(preds_1d, preds_2d)
+
+
+@pytest.mark.skipif(not PANDAS_INSTALLED, reason='pandas is not installed')
+@pytest.mark.parametrize('task', ['classification', 'ranking', 'regression'])
+def test_validate_features(task):
+    X, y, g = _create_data(task)
+    features = ['x1', 'x2', 'x3', 'x4']
+    df = pd_DataFrame(X, columns=features)
+    model = task_to_model_factory[task](n_estimators=10, num_leaves=15, verbose=-1)
+    if task == 'ranking':
+        model.fit(df, y, group=g)
+    else:
+        model.fit(df, y)
+    assert model.booster_.feature_name() == features
+
+    # try to predict with a different feature
+    df2 = df.rename(columns={'x1': 'z'})
+    with pytest.raises(ValueError, match="The following features are missing: {'x1'}"):
+        model.predict(df2)
+
+    # check that disabling the check doesn't raise the error
+    model.predict(df2, validate_features=False)
+
+    # predict with the features out of order
+    preds_sorted_features = model.predict(df[features])
+    scrambled_features = ['x3', 'x1', 'x4', 'x2']
+    preds_scrambled_features = model.predict(df[scrambled_features])
+    np.testing.assert_equal(preds_sorted_features, preds_scrambled_features)
+
+    # check that disabling the check doesn't raise an error and produces incorrect predictions
+    preds_scrambled_features_no_check = model.predict(df[scrambled_features], validate_features=False)
+    assert any(preds_sorted_features != preds_scrambled_features_no_check)

From 0874157e1f50cfb4b0853b9d9917232542f64cef Mon Sep 17 00:00:00 2001
From: Jose Morales
Date: Thu, 30 Dec 2021 19:47:19 -0600
Subject: [PATCH 07/19] document validate_features argument

---
 python-package/lightgbm/basic.py   | 3 +++
 python-package/lightgbm/sklearn.py | 3 +++
 2 files changed, 6 insertions(+)

diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py
index a39aac87b330..e37c255c8456 100644
--- a/python-package/lightgbm/basic.py
+++ b/python-package/lightgbm/basic.py
@@ -3493,6 +3493,9 @@ def predict(self, data, start_iteration=0, num_iteration=None,
             Used only if data is str.
         is_reshape : bool, optional (default=True)
             If True, result is reshaped to [nrow, ncol].
+        validate_features : bool, optional (default=True)
+            If True, ensure that the features used to predict match the ones used to train.
+            Used only if data is pandas DataFrame.
         **kwargs
             Other parameters for the prediction.
 
diff --git a/python-package/lightgbm/sklearn.py b/python-package/lightgbm/sklearn.py
index 39ccb800d98c..509f68ddc3db 100644
--- a/python-package/lightgbm/sklearn.py
+++ b/python-package/lightgbm/sklearn.py
@@ -340,6 +340,9 @@ def __call__(self, preds, dataset):
         Note that unlike the shap package, with ``pred_contrib`` we return a matrix with an extra
         column, where the last column is the expected value.
+    validate_features : bool, optional (default=True)
+        If True, ensure that the features used to predict match the ones used to train.
+        Used only if data is pandas DataFrame.
     **kwargs
         Other parameters for the prediction.
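Up to this point in the series the whole check lives on the Python side. The snippet below is a minimal sketch of the behavior that [PATCH 01-07/19] give Booster.predict; the random data, training parameters, and column names are illustrative (they mirror the tests above), and note that later patches in the series move the check into C++ and drop the silent reordering:

    import numpy as np
    import pandas as pd
    import lightgbm as lgb

    # train on a DataFrame so the booster records its feature names
    rng = np.random.default_rng(0)
    df = pd.DataFrame(rng.random((100, 4)), columns=['x1', 'x2', 'x3', 'x4'])
    y = rng.random(100)
    bst = lgb.train({'num_leaves': 15, 'verbose': -1}, lgb.Dataset(df, y), num_boost_round=10)

    # validate_features=True (the default) matches the columns against
    # bst.feature_name() and reorders them before predicting, so a
    # scrambled frame still yields the same predictions
    preds = bst.predict(df[['x3', 'x1', 'x4', 'x2']])

    # a renamed column makes the check fail with
    #   ValueError: The following features are missing: {'x1'}
    df2 = df.rename(columns={'x1': 'z'})

    # callers that know the layout is correct can skip the check
    preds_unchecked = bst.predict(df2, validate_features=False)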
From 047c62127d71ad1010ebb1cbce3fd19700d04d1f Mon Sep 17 00:00:00 2001 From: Jose Morales Date: Thu, 30 Dec 2021 20:05:07 -0600 Subject: [PATCH 08/19] use all_close in preds checks and check for assertion error to compare different arrays --- tests/python_package_test/test_basic.py | 4 ++-- tests/python_package_test/test_sklearn.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py index bb052677faaf..0800e0a2565d 100644 --- a/tests/python_package_test/test_basic.py +++ b/tests/python_package_test/test_basic.py @@ -632,8 +632,8 @@ def test_validate_features(): preds_sorted_features = bst.predict(df[features]) scrambled_features = ['x3', 'x1', 'x4', 'x2'] preds_scrambled_features = bst.predict(df[scrambled_features]) - np.testing.assert_equal(preds_sorted_features, preds_scrambled_features) + np.testing.assert_allclose(preds_sorted_features, preds_scrambled_features) # check that disabling the check doesn't raise an error and produces incorrect predictions preds_scrambled_features_no_check = bst.predict(df[scrambled_features], validate_features=False) - assert any(preds_sorted_features != preds_scrambled_features_no_check) \ No newline at end of file + np.testing.assert_raises(AssertionError, np.testing.assert_allclose, preds_sorted_features, preds_scrambled_features_no_check) diff --git a/tests/python_package_test/test_sklearn.py b/tests/python_package_test/test_sklearn.py index c77cd024aac9..3855eede6af4 100644 --- a/tests/python_package_test/test_sklearn.py +++ b/tests/python_package_test/test_sklearn.py @@ -1391,8 +1391,8 @@ def test_validate_features(task): preds_sorted_features = model.predict(df[features]) scrambled_features = ['x3', 'x1', 'x4', 'x2'] preds_scrambled_features = model.predict(df[scrambled_features]) - np.testing.assert_equal(preds_sorted_features, preds_scrambled_features) + np.testing.assert_allclose(preds_sorted_features, preds_scrambled_features) # check that disabling the check doesn't raise an error and produces incorrect predictions preds_scrambled_features_no_check = model.predict(df[scrambled_features], validate_features=False) - assert any(preds_sorted_features != preds_scrambled_features_no_check) + np.testing.assert_raises(AssertionError, np.testing.assert_allclose, preds_sorted_features, preds_scrambled_features_no_check) From f85c25f4ac7e27cb15d4c177c9fd9e445ae418b5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Morales?= Date: Sun, 27 Feb 2022 11:08:13 -0600 Subject: [PATCH 09/19] perform remapping and checks in cpp --- include/LightGBM/c_api.h | 16 ++++ python-package/lightgbm/basic.py | 101 ++++++++++++++-------- src/c_api.cpp | 85 ++++++++++++++++++ tests/python_package_test/test_basic.py | 28 ------ tests/python_package_test/test_engine.py | 28 ++++++ tests/python_package_test/test_sklearn.py | 2 +- 6 files changed, 195 insertions(+), 65 deletions(-) diff --git a/include/LightGBM/c_api.h b/include/LightGBM/c_api.h index ed639b7f298c..a830347c9732 100644 --- a/include/LightGBM/c_api.h +++ b/include/LightGBM/c_api.h @@ -1033,6 +1033,22 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSC(BoosterHandle handle, int64_t* out_len, double* out_result); +/*! 
+ */ +LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForNamedMat(BoosterHandle handle, + const void* data, + const char** names, + int data_type, + int32_t nrow, + int32_t ncol, + int is_row_major, + int predict_type, + int start_iteration, + int num_iteration, + const char* parameter, + int64_t* out_len, + double* out_result); + /*! * \brief Make prediction for a new dataset. * \note diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 9f48c8d81474..beea3bfcd670 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -512,23 +512,12 @@ def _get_bad_pandas_dtypes(dtypes): return bad_indices -def _data_from_pandas(data, feature_name, categorical_feature, pandas_categorical, validate_features=False): +def _data_from_pandas(data, feature_name, categorical_feature, pandas_categorical): if isinstance(data, pd_DataFrame): if len(data.shape) != 2 or data.shape[0] < 1: raise ValueError('Input data must be 2 dimensional and non empty.') if feature_name == 'auto' or feature_name is None: data = data.rename(columns=str) - elif isinstance(feature_name, list) and validate_features: - df_features = [str(x) for x in data.columns] - missing_features = set(feature_name) - set(df_features) - if missing_features: - raise ValueError( - f"The following features are missing: {missing_features}.\n" - "If you're sure the features are correct you can disable this check by setting validate_features=False" - ) - sort_idxs = [df_features.index(feature) for feature in feature_name] - if not all(x == i for i, x in enumerate(sort_idxs)): - data = data.iloc[:, sort_idxs] # ensure column order cat_cols = [col for col, dtype in zip(data.columns, data.dtypes) if isinstance(dtype, pd_CategoricalDtype)] cat_cols_not_ordered = [col for col in cat_cols if not data[col].cat.ordered] if pandas_categorical is None: # train dataset @@ -746,9 +735,18 @@ def __getstate__(self): this.pop('handle', None) return this - def predict(self, data, start_iteration=0, num_iteration=-1, - raw_score=False, pred_leaf=False, pred_contrib=False, data_has_header=False, - is_reshape=True): + def predict( + self, + data, + start_iteration=0, + num_iteration=-1, + raw_score=False, + pred_leaf=False, + pred_contrib=False, + data_has_header=False, + is_reshape=True, + data_names=None, + ): """Predict logic. 
Parameters @@ -805,15 +803,15 @@ def predict(self, data, start_iteration=0, num_iteration=-1, elif isinstance(data, scipy.sparse.csc_matrix): preds, nrow = self.__pred_for_csc(data, start_iteration, num_iteration, predict_type) elif isinstance(data, np.ndarray): - preds, nrow = self.__pred_for_np2d(data, start_iteration, num_iteration, predict_type) + preds, nrow = self.__pred_for_np2d(data, start_iteration, num_iteration, predict_type, data_names) elif isinstance(data, list): try: data = np.array(data) except BaseException: raise ValueError('Cannot convert data list to numpy array.') - preds, nrow = self.__pred_for_np2d(data, start_iteration, num_iteration, predict_type) + preds, nrow = self.__pred_for_np2d(data, start_iteration, num_iteration, predict_type, None) elif isinstance(data, dt_DataTable): - preds, nrow = self.__pred_for_np2d(data.to_numpy(), start_iteration, num_iteration, predict_type) + preds, nrow = self.__pred_for_np2d(data.to_numpy(), start_iteration, num_iteration, predict_type, None) else: try: _log_warning('Converting data to scipy sparse matrix.') @@ -848,7 +846,7 @@ def __get_num_preds(self, start_iteration, num_iteration, nrow, predict_type): ctypes.byref(n_preds))) return n_preds.value - def __pred_for_np2d(self, mat, start_iteration, num_iteration, predict_type): + def __pred_for_np2d(self, mat, start_iteration, num_iteration, predict_type, data_names): """Predict for a 2-D numpy matrix.""" if len(mat.shape) != 2: raise ValueError('Input numpy.ndarray or list must be 2 dimensional') @@ -865,19 +863,38 @@ def inner_predict(mat, start_iteration, num_iteration, predict_type, preds=None) elif len(preds.shape) != 1 or len(preds) != n_preds: raise ValueError("Wrong length of pre-allocated predict array") out_num_preds = ctypes.c_int64(0) - _safe_call(_LIB.LGBM_BoosterPredictForMat( - self.handle, - ptr_data, - ctypes.c_int(type_ptr_data), - ctypes.c_int32(mat.shape[0]), - ctypes.c_int32(mat.shape[1]), - ctypes.c_int(C_API_IS_ROW_MAJOR), - ctypes.c_int(predict_type), - ctypes.c_int(start_iteration), - ctypes.c_int(num_iteration), - c_str(self.pred_parameter), - ctypes.byref(out_num_preds), - preds.ctypes.data_as(ctypes.POINTER(ctypes.c_double)))) + if data_names is None: + _safe_call(_LIB.LGBM_BoosterPredictForMat( + self.handle, + ptr_data, + ctypes.c_int(type_ptr_data), + ctypes.c_int32(mat.shape[0]), + ctypes.c_int32(mat.shape[1]), + ctypes.c_int(C_API_IS_ROW_MAJOR), + ctypes.c_int(predict_type), + ctypes.c_int(start_iteration), + ctypes.c_int(num_iteration), + c_str(self.pred_parameter), + ctypes.byref(out_num_preds), + preds.ctypes.data_as(ctypes.POINTER(ctypes.c_double)))) + else: + ptr_names = (ctypes.c_char_p * len(data_names))() + ptr_names[:] = [n.encode('utf-8') for n in data_names] + print(data_names) + _safe_call(_LIB.LGBM_BoosterPredictForNamedMat( + self.handle, + ptr_data, + ptr_names, + ctypes.c_int(type_ptr_data), + ctypes.c_int32(mat.shape[0]), + ctypes.c_int32(mat.shape[1]), + ctypes.c_int(C_API_IS_ROW_MAJOR), + ctypes.c_int(predict_type), + ctypes.c_int(start_iteration), + ctypes.c_int(num_iteration), + c_str(self.pred_parameter), + ctypes.byref(out_num_preds), + preds.ctypes.data_as(ctypes.POINTER(ctypes.c_double)))) if n_preds != out_num_preds.value: raise ValueError("Wrong length for predict results") return preds, mat.shape[0] @@ -3515,16 +3532,28 @@ def predict(self, data, start_iteration=0, num_iteration=None, """ if isinstance(data, Dataset): raise TypeError("Cannot use Dataset instance for prediction, please use raw data instead") - data = 
_data_from_pandas(data, self.feature_name(), None, self.pandas_categorical, validate_features=validate_features)[0] + elif isinstance(data, pd_DataFrame) and validate_features: + data_names = [str(col) for col in data.columns] + else: + data_names = None + data = _data_from_pandas(data, None, None, self.pandas_categorical)[0] predictor = self._to_predictor(deepcopy(kwargs)) if num_iteration is None: if start_iteration <= 0: num_iteration = self.best_iteration else: num_iteration = -1 - return predictor.predict(data, start_iteration, num_iteration, - raw_score, pred_leaf, pred_contrib, - data_has_header, is_reshape) + return predictor.predict( + data, + start_iteration, + num_iteration, + raw_score, + pred_leaf, + pred_contrib, + data_has_header, + is_reshape, + data_names, + ) def refit(self, data, label, decay_rate=0.9, **kwargs): """Refit the existing Booster by new data. diff --git a/src/c_api.cpp b/src/c_api.cpp index d8a8deaf57b0..1295bb4cc5e7 100644 --- a/src/c_api.cpp +++ b/src/c_api.cpp @@ -2134,6 +2134,91 @@ int LGBM_BoosterPredictForMat(BoosterHandle handle, API_END(); } + +template +void remap(const void* input, int32_t nrow, int32_t ncol, const std::unordered_map &remappings, int is_row_major, T* ptr_output) { + const T* ptr_input = static_cast(input); + for (auto& it : remappings) { + int src = it.first; + int dest = it.second; + Log::Info("Writing %d to %d", src, dest); + if (is_row_major) { + for (int32_t i = 0; i < nrow; ++i) { + ptr_output[dest + ncol * i] = ptr_input[src + ncol * i]; + } + } else { + for (int32_t i = 0; i < nrow; ++i) { + ptr_output[dest * ncol + i] = ptr_input[src * ncol + i]; + } + } + } + Log::Info("Final"); + for (int i = 0; i < nrow * ncol; ++i) { + Log::Info("%f", ptr_output[i]); + } +} + + +int LGBM_BoosterPredictForNamedMat(BoosterHandle handle, + const void* data, + const char** names, + int data_type, + int32_t nrow, + int32_t ncol, + int is_row_major, + int predict_type, + int start_iteration, + int num_iteration, + const char* parameter, + int64_t* out_len, + double* out_result) { + API_BEGIN(); + int features_out_len; + size_t out_buffer_len; + LGBM_BoosterGetFeatureNames(handle, 0, &features_out_len, 0, &out_buffer_len, nullptr); + std::vector> tmp_names(features_out_len); + std::vector fnames(features_out_len); + for (int i = 0; i < features_out_len; ++i) { + tmp_names[i].resize(out_buffer_len); + fnames[i] = tmp_names[i].data(); + } + size_t allocated_buffer_len = out_buffer_len; + int expected_len = static_cast(ncol); + LGBM_BoosterGetFeatureNames(handle, expected_len, &features_out_len, allocated_buffer_len, &out_buffer_len, fnames.data()); + + std::unordered_map feature_positions; + for (int i = 0; i < ncol; ++i) { + feature_positions[names[i]] = i; + } + std::unordered_map remappings(ncol); + for (int i = 0; i < ncol; ++i) { + auto it = feature_positions.find(fnames[i]); + if (it == feature_positions.end()) { + Log::Fatal("%s not found in data", fnames[i]); + } + int pos = it->second; + if (pos != i) { + Log::Info("%s found at position %d, expected %d", fnames[i], pos, i); + remappings[pos] = i; + } + } + if (remappings.size() > 0) { // needs adjust + if (data_type == C_API_DTYPE_FLOAT32) { + float* ptr_remapped = (float*) malloc(nrow * ncol * sizeof(float)); + remap(data, nrow, ncol, remappings, is_row_major, ptr_remapped); + LGBM_BoosterPredictForMat(handle, ptr_remapped, data_type, nrow, ncol, is_row_major, predict_type, start_iteration, num_iteration, parameter, out_len, out_result); + } else { + double* ptr_remapped = (double*) 
malloc(nrow * ncol * sizeof(double)); + remap(data, nrow, ncol, remappings, is_row_major, ptr_remapped); + LGBM_BoosterPredictForMat(handle, ptr_remapped, data_type, nrow, ncol, is_row_major, predict_type, start_iteration, num_iteration, parameter, out_len, out_result); + } + } else { + LGBM_BoosterPredictForMat(handle, data, data_type, nrow, ncol, is_row_major, predict_type, start_iteration, num_iteration, parameter, out_len, out_result); + } + API_END(); +} + + int LGBM_BoosterPredictForMatSingleRow(BoosterHandle handle, const void* data, int data_type, diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py index 0800e0a2565d..aa77ace3aab6 100644 --- a/tests/python_package_test/test_basic.py +++ b/tests/python_package_test/test_basic.py @@ -609,31 +609,3 @@ def test_custom_objective_safety(): good_bst_multi.update(fobj=_good_gradients) with pytest.raises(ValueError, match=re.escape(f"number of models per one iteration ({nclass})")): bad_bst_multi.update(fobj=_bad_gradients) - - -@pytest.mark.skipif(not PANDAS_INSTALLED, reason='pandas is not installed') -def test_validate_features(): - X, y = make_synthetic_regression() - features = ['x1', 'x2', 'x3', 'x4'] - df = pd_DataFrame(X, columns=features) - ds = lgb.Dataset(df, y) - bst = lgb.train({'num_leaves': 15, 'verbose': -1}, ds, num_boost_round=10) - assert bst.feature_name() == features - - # try to predict with a different feature - df2 = df.rename(columns={'x1': 'z'}) - with pytest.raises(ValueError, match="The following features are missing: {'x1'}"): - bst.predict(df2) - - # check that disabling the check doesn't raise the error - bst.predict(df2, validate_features=False) - - # predict with the features out of order - preds_sorted_features = bst.predict(df[features]) - scrambled_features = ['x3', 'x1', 'x4', 'x2'] - preds_scrambled_features = bst.predict(df[scrambled_features]) - np.testing.assert_allclose(preds_sorted_features, preds_scrambled_features) - - # check that disabling the check doesn't raise an error and produces incorrect predictions - preds_scrambled_features_no_check = bst.predict(df[scrambled_features], validate_features=False) - np.testing.assert_raises(AssertionError, np.testing.assert_allclose, preds_sorted_features, preds_scrambled_features_no_check) diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py index a74056b2c948..28000956d522 100644 --- a/tests/python_package_test/test_engine.py +++ b/tests/python_package_test/test_engine.py @@ -3224,3 +3224,31 @@ def test_force_split_with_feature_fraction(tmp_path): for tree in tree_info: tree_structure = tree["tree_structure"] assert tree_structure['split_feature'] == 0 + + +@pytest.mark.skipif(not PANDAS_INSTALLED, reason='pandas is not installed') +def test_validate_features(): + X, y = make_synthetic_regression() + features = ['x1', 'x2', 'x3', 'x4'] + df = pd_DataFrame(X, columns=features) + ds = lgb.Dataset(df, y) + bst = lgb.train({'num_leaves': 15, 'verbose': -1}, ds, num_boost_round=10) + assert bst.feature_name() == features + + # try to predict with a different feature + df2 = df.rename(columns={'x1': 'z'}) + with pytest.raises(lgb.basic.LightGBMError, match="x1 not found in data"): + bst.predict(df2) + + # check that disabling the check doesn't raise the error + bst.predict(df2, validate_features=False) + + # predict with the features out of order + preds_sorted_features = bst.predict(df[features]) + scrambled_features = ['x3', 'x1', 'x4', 'x2'] + preds_scrambled_features = 
bst.predict(df[scrambled_features]) + np.testing.assert_allclose(preds_sorted_features, preds_scrambled_features) + + # check that disabling the check doesn't raise an error and produces incorrect predictions + preds_scrambled_features_no_check = bst.predict(df[scrambled_features], validate_features=False) + np.testing.assert_raises(AssertionError, np.testing.assert_allclose, preds_sorted_features, preds_scrambled_features_no_check) diff --git a/tests/python_package_test/test_sklearn.py b/tests/python_package_test/test_sklearn.py index 3855eede6af4..cd70318c39df 100644 --- a/tests/python_package_test/test_sklearn.py +++ b/tests/python_package_test/test_sklearn.py @@ -1381,7 +1381,7 @@ def test_validate_features(task): # try to predict with a different feature df2 = df.rename(columns={'x1': 'z'}) - with pytest.raises(ValueError, match="The following features are missing: {'x1'}"): + with pytest.raises(lgb.basic.LightGBMError, match="x1 not found in data"): model.predict(df2) # check that disabling the check doesn't raise the error From dbcec101c78e2cc581524b9aa79710a4cf386e17 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Morales?= Date: Sun, 27 Feb 2022 11:10:20 -0600 Subject: [PATCH 10/19] remove extra logs --- src/c_api.cpp | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/c_api.cpp b/src/c_api.cpp index 1295bb4cc5e7..e01ca6cf0c4d 100644 --- a/src/c_api.cpp +++ b/src/c_api.cpp @@ -2141,7 +2141,6 @@ void remap(const void* input, int32_t nrow, int32_t ncol, const std::unordered_m for (auto& it : remappings) { int src = it.first; int dest = it.second; - Log::Info("Writing %d to %d", src, dest); if (is_row_major) { for (int32_t i = 0; i < nrow; ++i) { ptr_output[dest + ncol * i] = ptr_input[src + ncol * i]; @@ -2152,10 +2151,6 @@ void remap(const void* input, int32_t nrow, int32_t ncol, const std::unordered_m } } } - Log::Info("Final"); - for (int i = 0; i < nrow * ncol; ++i) { - Log::Info("%f", ptr_output[i]); - } } From 2469149716ef7f79a2503d45d9fcfc7d2bb9545d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Morales?= Date: Sun, 27 Feb 2022 11:23:52 -0600 Subject: [PATCH 11/19] fixes --- src/c_api.cpp | 2 +- tests/python_package_test/test_basic.py | 4 ++-- tests/python_package_test/test_engine.py | 1 + tests/python_package_test/test_sklearn.py | 2 +- 4 files changed, 5 insertions(+), 4 deletions(-) diff --git a/src/c_api.cpp b/src/c_api.cpp index c3b1456052c4..ec78107b77b8 100644 --- a/src/c_api.cpp +++ b/src/c_api.cpp @@ -2138,7 +2138,7 @@ int LGBM_BoosterPredictForMat(BoosterHandle handle, template void remap(const void* input, int32_t nrow, int32_t ncol, const std::unordered_map &remappings, int is_row_major, T* ptr_output) { const T* ptr_input = static_cast(input); - for (auto& it : remappings) { + for (const auto& it : remappings) { int src = it.first; int dest = it.second; if (is_row_major) { diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py index 7a841cd7ec7e..e49a6009d0a3 100644 --- a/tests/python_package_test/test_basic.py +++ b/tests/python_package_test/test_basic.py @@ -8,13 +8,13 @@ import pytest from scipy import sparse from sklearn.datasets import dump_svmlight_file, load_svmlight_file, make_blobs -from sklearn.metrics import log_loss, mean_squared_error +from sklearn.metrics import log_loss from sklearn.model_selection import train_test_split import lightgbm as lgb from lightgbm.compat import PANDAS_INSTALLED, pd_DataFrame, pd_Series -from .utils import load_breast_cancer, make_synthetic_regression, 
sklearn_multiclass_custom_objective, softmax +from .utils import load_breast_cancer, sklearn_multiclass_custom_objective, softmax def test_basic(tmp_path): diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py index f8949b08a39d..f6ef7a70ecfc 100644 --- a/tests/python_package_test/test_engine.py +++ b/tests/python_package_test/test_engine.py @@ -17,6 +17,7 @@ from sklearn.model_selection import GroupKFold, TimeSeriesSplit, train_test_split import lightgbm as lgb +from lightgbm.compat import PANDAS_INSTALLED, pd_DataFrame from .utils import load_boston, load_breast_cancer, load_digits, load_iris, make_synthetic_regression diff --git a/tests/python_package_test/test_sklearn.py b/tests/python_package_test/test_sklearn.py index bfba78fb83f9..54eb00060a5e 100644 --- a/tests/python_package_test/test_sklearn.py +++ b/tests/python_package_test/test_sklearn.py @@ -1339,4 +1339,4 @@ def test_validate_features(task): # check that disabling the check doesn't raise an error and produces incorrect predictions preds_scrambled_features_no_check = model.predict(df[scrambled_features], validate_features=False) - np.testing.assert_raises(AssertionError, np.testing.assert_allclose, preds_sorted_features, preds_scrambled_features_no_check) \ No newline at end of file + np.testing.assert_raises(AssertionError, np.testing.assert_allclose, preds_sorted_features, preds_scrambled_features_no_check) From 91f2f8e55a173ed671a6172fee4466de73d3ab01 Mon Sep 17 00:00:00 2001 From: Jose Morales Date: Tue, 1 Mar 2022 17:34:01 -0600 Subject: [PATCH 12/19] revert cpp --- include/LightGBM/c_api.h | 16 ---- python-package/lightgbm/basic.py | 101 ++++++++-------------- src/c_api.cpp | 6 ++ tests/python_package_test/test_basic.py | 1 + tests/python_package_test/test_engine.py | 1 + tests/python_package_test/test_sklearn.py | 2 +- 6 files changed, 45 insertions(+), 82 deletions(-) diff --git a/include/LightGBM/c_api.h b/include/LightGBM/c_api.h index d6def384b737..cfce5c08f993 100644 --- a/include/LightGBM/c_api.h +++ b/include/LightGBM/c_api.h @@ -1033,22 +1033,6 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForCSC(BoosterHandle handle, int64_t* out_len, double* out_result); -/*! - */ -LIGHTGBM_C_EXPORT int LGBM_BoosterPredictForNamedMat(BoosterHandle handle, - const void* data, - const char** names, - int data_type, - int32_t nrow, - int32_t ncol, - int is_row_major, - int predict_type, - int start_iteration, - int num_iteration, - const char* parameter, - int64_t* out_len, - double* out_result); - /*! * \brief Make prediction for a new dataset. 
* \note diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 4e7e7a6d606d..69325cbf0cd5 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -512,12 +512,23 @@ def is_allowed_numpy_dtype(dtype): return [i for i, dtype in enumerate(dtypes) if not is_allowed_numpy_dtype(dtype.type)] -def _data_from_pandas(data, feature_name, categorical_feature, pandas_categorical): +def _data_from_pandas(data, feature_name, categorical_feature, pandas_categorical, validate_features=False): if isinstance(data, pd_DataFrame): if len(data.shape) != 2 or data.shape[0] < 1: raise ValueError('Input data must be 2 dimensional and non empty.') if feature_name == 'auto' or feature_name is None: data = data.rename(columns=str) + elif isinstance(feature_name, list) and validate_features: + df_features = [str(x) for x in data.columns] + missing_features = set(feature_name) - set(df_features) + if missing_features: + raise ValueError( + f"The following features are missing: {missing_features}.\n" + "If you're sure the features are correct you can disable this check by setting validate_features=False" + ) + sort_idxs = [df_features.index(feature) for feature in feature_name] + if not all(x == i for i, x in enumerate(sort_idxs)): + data = data.iloc[:, sort_idxs] # ensure column order cat_cols = [col for col, dtype in zip(data.columns, data.dtypes) if isinstance(dtype, pd_CategoricalDtype)] cat_cols_not_ordered = [col for col in cat_cols if not data[col].cat.ordered] if pandas_categorical is None: # train dataset @@ -736,18 +747,9 @@ def __getstate__(self): this.pop('handle', None) return this - def predict( - self, - data, - start_iteration=0, - num_iteration=-1, - raw_score=False, - pred_leaf=False, - pred_contrib=False, - data_has_header=False, - is_reshape=True, - data_names=None, - ): + def predict(self, data, start_iteration=0, num_iteration=-1, + raw_score=False, pred_leaf=False, pred_contrib=False, data_has_header=False, + is_reshape=True): """Predict logic. 
Parameters @@ -804,15 +806,15 @@ def predict( elif isinstance(data, scipy.sparse.csc_matrix): preds, nrow = self.__pred_for_csc(data, start_iteration, num_iteration, predict_type) elif isinstance(data, np.ndarray): - preds, nrow = self.__pred_for_np2d(data, start_iteration, num_iteration, predict_type, data_names) + preds, nrow = self.__pred_for_np2d(data, start_iteration, num_iteration, predict_type) elif isinstance(data, list): try: data = np.array(data) except BaseException: raise ValueError('Cannot convert data list to numpy array.') - preds, nrow = self.__pred_for_np2d(data, start_iteration, num_iteration, predict_type, None) + preds, nrow = self.__pred_for_np2d(data, start_iteration, num_iteration, predict_type) elif isinstance(data, dt_DataTable): - preds, nrow = self.__pred_for_np2d(data.to_numpy(), start_iteration, num_iteration, predict_type, None) + preds, nrow = self.__pred_for_np2d(data.to_numpy(), start_iteration, num_iteration, predict_type) else: try: _log_warning('Converting data to scipy sparse matrix.') @@ -847,7 +849,7 @@ def __get_num_preds(self, start_iteration, num_iteration, nrow, predict_type): ctypes.byref(n_preds))) return n_preds.value - def __pred_for_np2d(self, mat, start_iteration, num_iteration, predict_type, data_names): + def __pred_for_np2d(self, mat, start_iteration, num_iteration, predict_type): """Predict for a 2-D numpy matrix.""" if len(mat.shape) != 2: raise ValueError('Input numpy.ndarray or list must be 2 dimensional') @@ -864,38 +866,19 @@ def inner_predict(mat, start_iteration, num_iteration, predict_type, preds=None) elif len(preds.shape) != 1 or len(preds) != n_preds: raise ValueError("Wrong length of pre-allocated predict array") out_num_preds = ctypes.c_int64(0) - if data_names is None: - _safe_call(_LIB.LGBM_BoosterPredictForMat( - self.handle, - ptr_data, - ctypes.c_int(type_ptr_data), - ctypes.c_int32(mat.shape[0]), - ctypes.c_int32(mat.shape[1]), - ctypes.c_int(C_API_IS_ROW_MAJOR), - ctypes.c_int(predict_type), - ctypes.c_int(start_iteration), - ctypes.c_int(num_iteration), - c_str(self.pred_parameter), - ctypes.byref(out_num_preds), - preds.ctypes.data_as(ctypes.POINTER(ctypes.c_double)))) - else: - ptr_names = (ctypes.c_char_p * len(data_names))() - ptr_names[:] = [n.encode('utf-8') for n in data_names] - print(data_names) - _safe_call(_LIB.LGBM_BoosterPredictForNamedMat( - self.handle, - ptr_data, - ptr_names, - ctypes.c_int(type_ptr_data), - ctypes.c_int32(mat.shape[0]), - ctypes.c_int32(mat.shape[1]), - ctypes.c_int(C_API_IS_ROW_MAJOR), - ctypes.c_int(predict_type), - ctypes.c_int(start_iteration), - ctypes.c_int(num_iteration), - c_str(self.pred_parameter), - ctypes.byref(out_num_preds), - preds.ctypes.data_as(ctypes.POINTER(ctypes.c_double)))) + _safe_call(_LIB.LGBM_BoosterPredictForMat( + self.handle, + ptr_data, + ctypes.c_int(type_ptr_data), + ctypes.c_int32(mat.shape[0]), + ctypes.c_int32(mat.shape[1]), + ctypes.c_int(C_API_IS_ROW_MAJOR), + ctypes.c_int(predict_type), + ctypes.c_int(start_iteration), + ctypes.c_int(num_iteration), + c_str(self.pred_parameter), + ctypes.byref(out_num_preds), + preds.ctypes.data_as(ctypes.POINTER(ctypes.c_double)))) if n_preds != out_num_preds.value: raise ValueError("Wrong length for predict results") return preds, mat.shape[0] @@ -3528,28 +3511,16 @@ def predict(self, data, start_iteration=0, num_iteration=None, """ if isinstance(data, Dataset): raise TypeError("Cannot use Dataset instance for prediction, please use raw data instead") - elif isinstance(data, pd_DataFrame) and 
validate_features: - data_names = [str(col) for col in data.columns] - else: - data_names = None - data = _data_from_pandas(data, None, None, self.pandas_categorical)[0] + data = _data_from_pandas(data, self.feature_name(), None, self.pandas_categorical, validate_features=validate_features)[0] predictor = self._to_predictor(deepcopy(kwargs)) if num_iteration is None: if start_iteration <= 0: num_iteration = self.best_iteration else: num_iteration = -1 - return predictor.predict( - data, - start_iteration, - num_iteration, - raw_score, - pred_leaf, - pred_contrib, - data_has_header, - is_reshape, - data_names, - ) + return predictor.predict(data, start_iteration, num_iteration, + raw_score, pred_leaf, pred_contrib, + data_has_header, is_reshape) def refit( self, diff --git a/src/c_api.cpp b/src/c_api.cpp index ec78107b77b8..48716ebaec89 100644 --- a/src/c_api.cpp +++ b/src/c_api.cpp @@ -2134,6 +2134,7 @@ int LGBM_BoosterPredictForMat(BoosterHandle handle, API_END(); } +<<<<<<< HEAD template void remap(const void* input, int32_t nrow, int32_t ncol, const std::unordered_map &remappings, int is_row_major, T* ptr_output) { @@ -2141,6 +2142,7 @@ void remap(const void* input, int32_t nrow, int32_t ncol, const std::unordered_m for (const auto& it : remappings) { int src = it.first; int dest = it.second; + Log::Info("Writing %d to %d", src, dest); if (is_row_major) { for (int32_t i = 0; i < nrow; ++i) { ptr_output[dest + ncol * i] = ptr_input[src + ncol * i]; @@ -2151,6 +2153,10 @@ void remap(const void* input, int32_t nrow, int32_t ncol, const std::unordered_m } } } + Log::Info("Final"); + for (int i = 0; i < nrow * ncol; ++i) { + Log::Info("%f", ptr_output[i]); + } } diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py index e49a6009d0a3..29f17163920a 100644 --- a/tests/python_package_test/test_basic.py +++ b/tests/python_package_test/test_basic.py @@ -612,6 +612,7 @@ def test_custom_objective_safety(): bad_bst_multi.update(fobj=_bad_gradients) +<<<<<<< HEAD def test_multiclass_custom_objective(): def custom_obj(y_pred, ds): y_true = ds.get_label() diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py index f6ef7a70ecfc..0dbe2dc0d45b 100644 --- a/tests/python_package_test/test_engine.py +++ b/tests/python_package_test/test_engine.py @@ -3253,6 +3253,7 @@ def test_force_split_with_feature_fraction(tmp_path): for tree in tree_info: tree_structure = tree["tree_structure"] assert tree_structure['split_feature'] == 0 +<<<<<<< HEAD def test_record_evaluation_with_train(): diff --git a/tests/python_package_test/test_sklearn.py b/tests/python_package_test/test_sklearn.py index 54eb00060a5e..145f033c49be 100644 --- a/tests/python_package_test/test_sklearn.py +++ b/tests/python_package_test/test_sklearn.py @@ -1325,7 +1325,7 @@ def test_validate_features(task): # try to predict with a different feature df2 = df.rename(columns={'x1': 'z'}) - with pytest.raises(lgb.basic.LightGBMError, match="x1 not found in data"): + with pytest.raises(ValueError, match="The following features are missing: {'x1'}"): model.predict(df2) # check that disabling the check doesn't raise the error From bdce26b5dd1f1bd69b3e3189cba4d8b723d53c86 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jos=C3=A9=20Morales?= Date: Wed, 25 May 2022 17:57:23 -0500 Subject: [PATCH 13/19] proposal --- include/LightGBM/c_api.h | 11 +++++++++ python-package/lightgbm/basic.py | 29 +++++++++++------------ src/c_api.cpp | 25 +++++++++++++++++++ 
tests/python_package_test/test_engine.py | 14 ++--------- tests/python_package_test/test_sklearn.py | 14 ++--------- 5 files changed, 54 insertions(+), 39 deletions(-) diff --git a/include/LightGBM/c_api.h b/include/LightGBM/c_api.h index 22f8807ac3e8..ac30bf311f00 100644 --- a/include/LightGBM/c_api.h +++ b/include/LightGBM/c_api.h @@ -678,6 +678,17 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterGetFeatureNames(BoosterHandle handle, size_t* out_buffer_len, char** out_strs); +/*! + * \brief Check that the feature names of the data match the ones used to train the booster + * \param handle Handle of booster + * \param data_names array with the feature names in the data + * \param data_num_features number of features in the data + * \return 0 when succeed, -1 when failure happens + */ +LIGHTGBM_C_EXPORT int LGBM_BoosterValidateFeatureNames(BoosterHandle handle, + const char** data_names, + int data_num_features); + /*! * \brief Get number of features. * \param handle Handle of booster diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index ea1c4420eec2..e46708a6bd13 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -532,23 +532,12 @@ def is_allowed_numpy_dtype(dtype): f'Fields with bad pandas dtypes: {", ".join(bad_pandas_dtypes)}') -def _data_from_pandas(data, feature_name, categorical_feature, pandas_categorical, validate_features): +def _data_from_pandas(data, feature_name, categorical_feature, pandas_categorical): if isinstance(data, pd_DataFrame): if len(data.shape) != 2 or data.shape[0] < 1: raise ValueError('Input data must be 2 dimensional and non empty.') if feature_name == 'auto' or feature_name is None: data = data.rename(columns=str) - elif isinstance(feature_name, list) and validate_features: - df_features = [str(x) for x in data.columns] - missing_features = set(feature_name) - set(df_features) - if missing_features: - raise ValueError( - f"The following features are missing: {missing_features}.\n" - "If you're sure the features are correct you can disable this check by setting validate_features=False" - ) - sort_idxs = [df_features.index(feature) for feature in feature_name] - if not all(x == i for i, x in enumerate(sort_idxs)): - data = data.iloc[:, sort_idxs] # ensure column order cat_cols = [col for col, dtype in zip(data.columns, data.dtypes) if isinstance(dtype, pd_CategoricalDtype)] cat_cols_not_ordered = [col for col in cat_cols if not data[col].cat.ordered] if pandas_categorical is None: # train dataset @@ -1445,8 +1434,7 @@ def _lazy_init(self, data, label=None, reference=None, data, feature_name, categorical_feature, self.pandas_categorical = _data_from_pandas(data, feature_name, categorical_feature, - self.pandas_categorical, - False) + self.pandas_categorical) label = _label_from_pandas(label) # process for args @@ -3548,7 +3536,18 @@ def predict(self, data, start_iteration=0, num_iteration=None, """ if isinstance(data, Dataset): raise TypeError("Cannot use Dataset instance for prediction, please use raw data instead") - data = _data_from_pandas(data, self.feature_name(), None, self.pandas_categorical, validate_features)[0] + elif isinstance(data, pd_DataFrame) and validate_features: + data_names = [str(x) for x in data.columns] + ptr_names = (ctypes.c_char_p * len(data_names))() + ptr_names[:] = [x.encode('utf-8') for x in data_names] + _safe_call( + _LIB.LGBM_BoosterValidateFeatureNames( + self.handle, + ptr_names, + ctypes.c_int(len(data_names)), + ) + ) + data = _data_from_pandas(data, 
diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py
index ea1c4420eec2..e46708a6bd13 100644
--- a/python-package/lightgbm/basic.py
+++ b/python-package/lightgbm/basic.py
@@ -532,23 +532,12 @@ def is_allowed_numpy_dtype(dtype):
             f'Fields with bad pandas dtypes: {", ".join(bad_pandas_dtypes)}')
 
 
-def _data_from_pandas(data, feature_name, categorical_feature, pandas_categorical, validate_features):
+def _data_from_pandas(data, feature_name, categorical_feature, pandas_categorical):
     if isinstance(data, pd_DataFrame):
         if len(data.shape) != 2 or data.shape[0] < 1:
             raise ValueError('Input data must be 2 dimensional and non empty.')
         if feature_name == 'auto' or feature_name is None:
             data = data.rename(columns=str)
-        elif isinstance(feature_name, list) and validate_features:
-            df_features = [str(x) for x in data.columns]
-            missing_features = set(feature_name) - set(df_features)
-            if missing_features:
-                raise ValueError(
-                    f"The following features are missing: {missing_features}.\n"
-                    "If you're sure the features are correct you can disable this check by setting validate_features=False"
-                )
-            sort_idxs = [df_features.index(feature) for feature in feature_name]
-            if not all(x == i for i, x in enumerate(sort_idxs)):
-                data = data.iloc[:, sort_idxs]  # ensure column order
         cat_cols = [col for col, dtype in zip(data.columns, data.dtypes) if isinstance(dtype, pd_CategoricalDtype)]
         cat_cols_not_ordered = [col for col in cat_cols if not data[col].cat.ordered]
         if pandas_categorical is None:  # train dataset
@@ -1445,8 +1434,7 @@ def _lazy_init(self, data, label=None, reference=None,
             data, feature_name, categorical_feature, self.pandas_categorical = _data_from_pandas(data,
                                                                                                  feature_name,
                                                                                                  categorical_feature,
-                                                                                                 self.pandas_categorical,
-                                                                                                 False)
+                                                                                                 self.pandas_categorical)
             label = _label_from_pandas(label)
 
         # process for args
@@ -3548,7 +3536,18 @@ def predict(self, data, start_iteration=0, num_iteration=None,
         """
         if isinstance(data, Dataset):
             raise TypeError("Cannot use Dataset instance for prediction, please use raw data instead")
-        data = _data_from_pandas(data, self.feature_name(), None, self.pandas_categorical, validate_features=validate_features)[0]
+        elif isinstance(data, pd_DataFrame) and validate_features:
+            data_names = [str(x) for x in data.columns]
+            ptr_names = (ctypes.c_char_p * len(data_names))()
+            ptr_names[:] = [x.encode('utf-8') for x in data_names]
+            _safe_call(
+                _LIB.LGBM_BoosterValidateFeatureNames(
+                    self.handle,
+                    ptr_names,
+                    ctypes.c_int(len(data_names)),
+                )
+            )
+        data = _data_from_pandas(data, self.feature_name(), None, self.pandas_categorical)[0]
         predictor = self._to_predictor(deepcopy(kwargs))
         if num_iteration is None:
             if start_iteration <= 0:
diff --git a/src/c_api.cpp b/src/c_api.cpp
index 75fb2e4bd249..d9f474848822 100644
--- a/src/c_api.cpp
+++ b/src/c_api.cpp
@@ -2129,6 +2129,31 @@ int LGBM_BoosterPredictForCSC(BoosterHandle handle,
   API_END();
 }
 
+int LGBM_BoosterValidateFeatureNames(BoosterHandle handle,
+                                    const char** data_names,
+                                    int data_num_features) {
+  API_BEGIN();
+  int booster_num_features;
+  size_t out_buffer_len;
+  LGBM_BoosterGetFeatureNames(handle, 0, &booster_num_features, 0, &out_buffer_len, nullptr);
+  if (booster_num_features != data_num_features) {
+    Log::Fatal("Booster was trained on %d features, but got %d input features to predict.", booster_num_features, data_num_features);
+  }
+  std::vector<std::vector<char>> tmp_names(booster_num_features);
+  std::vector<char*> booster_names(booster_num_features);
+  for (int i = 0; i < booster_num_features; ++i) {
+    tmp_names[i].resize(out_buffer_len);
+    booster_names[i] = tmp_names[i].data();
+  }
+  LGBM_BoosterGetFeatureNames(handle, data_num_features, &booster_num_features, out_buffer_len, &out_buffer_len, booster_names.data());
+  for (int i = 0; i < booster_num_features; ++i) {
+    if (strcmp(data_names[i], booster_names[i]) != 0) {
+      Log::Fatal("Expected '%s' at position %d but found '%s'", booster_names[i], i, data_names[i]);
+    }
+  }
+  API_END();
+}
+
 int LGBM_BoosterPredictForMat(BoosterHandle handle,
                               const void* data,
                               int data_type,
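
The implementation above relies on the size-then-fill convention of `LGBM_BoosterGetFeatureNames`: a first call with a null output buffer reports the feature count and the required per-string buffer length, and a second call fills caller-allocated buffers. A hedged Python sketch of the same two-call dance through ctypes (assuming `lib` is the loaded lib_lightgbm and `handle` a valid booster handle):

    import ctypes

    def booster_feature_names(lib, handle):
        num = ctypes.c_int(0)
        buf_len = ctypes.c_size_t(0)
        # First call: null buffer, only query the count and required length.
        lib.LGBM_BoosterGetFeatureNames(handle, 0, ctypes.byref(num),
                                        ctypes.c_size_t(0), ctypes.byref(buf_len), None)
        # Second call: allocate one buffer per feature name and fill them.
        bufs = [ctypes.create_string_buffer(buf_len.value) for _ in range(num.value)]
        ptrs = (ctypes.c_char_p * num.value)(*map(ctypes.addressof, bufs))
        lib.LGBM_BoosterGetFeatureNames(handle, num.value, ctypes.byref(num),
                                        buf_len, ctypes.byref(buf_len), ptrs)
        return [b.value.decode('utf-8') for b in bufs]
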
diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py
index 350c11c9f68c..b57518bca455 100644
--- a/tests/python_package_test/test_engine.py
+++ b/tests/python_package_test/test_engine.py
@@ -3591,19 +3591,9 @@ def test_validate_features():
     assert bst.feature_name() == features
 
     # try to predict with a different feature
-    df2 = df.rename(columns={'x1': 'z'})
-    with pytest.raises(ValueError, match="The following features are missing: {'x1'}"):
+    df2 = df.rename(columns={'x3': 'z'})
+    with pytest.raises(lgb.basic.LightGBMError, match="Expected 'x3' at position 2 but found 'z'"):
         bst.predict(df2, validate_features=True)
 
     # check that disabling the check doesn't raise the error
     bst.predict(df2, validate_features=False)
-
-    # predict with the features out of order
-    preds_sorted_features = bst.predict(df[features])
-    scrambled_features = ['x3', 'x1', 'x4', 'x2']
-    preds_scrambled_features = bst.predict(df[scrambled_features], validate_features=True)
-    np.testing.assert_allclose(preds_sorted_features, preds_scrambled_features)
-
-    # check that disabling the check doesn't raise an error and produces incorrect predictions
-    preds_scrambled_features_no_check = bst.predict(df[scrambled_features], validate_features=False)
-    np.testing.assert_raises(AssertionError, np.testing.assert_allclose, preds_sorted_features, preds_scrambled_features_no_check)
diff --git a/tests/python_package_test/test_sklearn.py b/tests/python_package_test/test_sklearn.py
index 5a6a58dc1fa1..b63ea5619660 100644
--- a/tests/python_package_test/test_sklearn.py
+++ b/tests/python_package_test/test_sklearn.py
@@ -1315,19 +1315,9 @@ def test_validate_features(task):
     assert model.booster_.feature_name() == features
 
     # try to predict with a different feature
-    df2 = df.rename(columns={'x1': 'z'})
-    with pytest.raises(ValueError, match="The following features are missing: {'x1'}"):
+    df2 = df.rename(columns={'x2': 'z'})
+    with pytest.raises(lgb.basic.LightGBMError, match="Expected 'x2' at position 1 but found 'z'"):
         model.predict(df2, validate_features=True)
 
     # check that disabling the check doesn't raise the error
     model.predict(df2, validate_features=False)
-
-    # predict with the features out of order
-    preds_sorted_features = model.predict(df[features])
-    scrambled_features = ['x3', 'x1', 'x4', 'x2']
-    preds_scrambled_features = model.predict(df[scrambled_features], validate_features=True)
-    np.testing.assert_allclose(preds_sorted_features, preds_scrambled_features)
-
-    # check that disabling the check doesn't raise an error and produces incorrect predictions
-    preds_scrambled_features_no_check = model.predict(df[scrambled_features], validate_features=False)
-    np.testing.assert_raises(AssertionError, np.testing.assert_allclose, preds_sorted_features, preds_scrambled_features_no_check)

From a3c35d7d5a76062005199c341177112efbfe629b Mon Sep 17 00:00:00 2001
From: José Morales
Date: Wed, 25 May 2022 18:51:11 -0500
Subject: [PATCH 14/19] remove extra arg

---
 tests/python_package_test/test_basic.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py
index 3100eba70711..d290bcb7216c 100644
--- a/tests/python_package_test/test_basic.py
+++ b/tests/python_package_test/test_basic.py
@@ -650,7 +650,7 @@ def test_no_copy_when_single_float_dtype_dataframe(dtype):
     df = pd.DataFrame(X)
     # feature names are required to not make a copy (rename makes a copy)
     feature_name = ['x1', 'x2']
-    built_data = lgb.basic._data_from_pandas(df, feature_name, None, None, False)[0]
+    built_data = lgb.basic._data_from_pandas(df, feature_name, None, None)[0]
     assert built_data.dtype == dtype
    assert np.shares_memory(X, built_data)

From b40580661df60f514c31ba1dc724aed413e67a82 Mon Sep 17 00:00:00 2001
From: José Morales
Date: Mon, 30 May 2022 15:57:14 -0500
Subject: [PATCH 15/19] lint

---
 src/c_api.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/c_api.cpp b/src/c_api.cpp
index d9f474848822..2bfbcbdf4f42 100644
--- a/src/c_api.cpp
+++ b/src/c_api.cpp
@@ -2130,8 +2130,8 @@ int LGBM_BoosterPredictForCSC(BoosterHandle handle,
 }
 
 int LGBM_BoosterValidateFeatureNames(BoosterHandle handle,
-                                    const char** data_names,
-                                    int data_num_features) {
+                                     const char** data_names,
+                                     int data_num_features) {
   API_BEGIN();
   int booster_num_features;
   size_t out_buffer_len;

From efae6adc8bc20ac3f36f5a3827ecfc38f655bf83 Mon Sep 17 00:00:00 2001
From: José Morales
Date: Mon, 30 May 2022 16:04:32 -0500
Subject: [PATCH 16/19] restore _data_from_pandas arguments

---
 python-package/lightgbm/basic.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py
index e46708a6bd13..d67ae7404c08 100644
--- a/python-package/lightgbm/basic.py
+++ b/python-package/lightgbm/basic.py
@@ -3547,7 +3547,7 @@ def predict(self, data, start_iteration=0, num_iteration=None,
-        data = _data_from_pandas(data, self.feature_name(), None, self.pandas_categorical)[0]
+        data = _data_from_pandas(data, None, None, self.pandas_categorical)[0]
         predictor = self._to_predictor(deepcopy(kwargs))
         if num_iteration is None:
             if start_iteration <= 0:
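
After the three cleanups above, `_data_from_pandas` is purely a conversion helper again: with `feature_name=None` it only stringifies the column labels and never reorders them, so the C-side name check is the only thing standing between a mismatched DataFrame and silently wrong predictions. A hedged sketch of that failure mode (illustrative data; the out-of-order test removed in patch 13 asserted exactly this):

    import numpy as np
    import pandas as pd
    import lightgbm as lgb

    df = pd.DataFrame({'x1': np.random.rand(100), 'x2': np.random.rand(100) * 10})
    y = np.random.rand(100)
    bst = lgb.train({'num_leaves': 7, 'verbose': -1}, lgb.Dataset(df, y), num_boost_round=5)

    swapped = df[['x2', 'x1']]
    preds = bst.predict(df)
    preds_swapped = bst.predict(swapped)        # no check: columns are used by position
    print(np.allclose(preds, preds_swapped))    # almost surely False
    # bst.predict(swapped, validate_features=True) would raise instead
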
From d89902153c35358bfbdbebc59af401bd744679da Mon Sep 17 00:00:00 2001
From: José Morales
Date: Wed, 15 Jun 2022 14:49:00 -0500
Subject: [PATCH 17/19] Apply suggestions from code review

Co-authored-by: Nikita Titov
---
 include/LightGBM/c_api.h                  | 6 +++---
 src/c_api.cpp                             | 2 +-
 tests/python_package_test/test_sklearn.py | 2 +-
 3 files changed, 5 insertions(+), 5 deletions(-)

diff --git a/include/LightGBM/c_api.h b/include/LightGBM/c_api.h
index ac30bf311f00..9d4b703c25b9 100644
--- a/include/LightGBM/c_api.h
+++ b/include/LightGBM/c_api.h
@@ -679,10 +679,10 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterGetFeatureNames(BoosterHandle handle,
                                                   char** out_strs);
 
 /*!
- * \brief Check that the feature names of the data match the ones used to train the booster
+ * \brief Check that the feature names of the data match the ones used to train the booster.
  * \param handle Handle of booster
- * \param data_names array with the feature names in the data
- * \param data_num_features number of features in the data
+ * \param data_names Array with the feature names in the data
+ * \param data_num_features Number of features in the data
  * \return 0 when succeed, -1 when failure happens
  */
 LIGHTGBM_C_EXPORT int LGBM_BoosterValidateFeatureNames(BoosterHandle handle,
diff --git a/src/c_api.cpp b/src/c_api.cpp
index 2bfbcbdf4f42..3615faad8ffc 100644
--- a/src/c_api.cpp
+++ b/src/c_api.cpp
@@ -2137,7 +2137,7 @@ int LGBM_BoosterValidateFeatureNames(BoosterHandle handle,
   size_t out_buffer_len;
   LGBM_BoosterGetFeatureNames(handle, 0, &booster_num_features, 0, &out_buffer_len, nullptr);
   if (booster_num_features != data_num_features) {
-    Log::Fatal("Booster was trained on %d features, but got %d input features to predict.", booster_num_features, data_num_features);
+    Log::Fatal("Model was trained on %d features, but got %d input features to predict.", booster_num_features, data_num_features);
   }
   std::vector<std::vector<char>> tmp_names(booster_num_features);
   std::vector<char*> booster_names(booster_num_features);
diff --git a/tests/python_package_test/test_sklearn.py b/tests/python_package_test/test_sklearn.py
index b63ea5619660..259a91ee7803 100644
--- a/tests/python_package_test/test_sklearn.py
+++ b/tests/python_package_test/test_sklearn.py
@@ -1312,7 +1312,7 @@ def test_validate_features(task):
         model.fit(df, y, group=g)
     else:
         model.fit(df, y)
-    assert model.booster_.feature_name() == features
+    assert model.feature_name_ == features
 
     # try to predict with a different feature
     df2 = df.rename(columns={'x2': 'z'})
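
The test change above leans on the scikit-learn wrapper exposing the fitted names as `feature_name_` rather than reaching through `booster_`. A short usage sketch of the interface the test exercises (illustrative data):

    import numpy as np
    import pandas as pd
    import lightgbm as lgb

    df = pd.DataFrame({'x1': np.random.rand(100), 'x2': np.random.rand(100)})
    y = np.random.rand(100)

    model = lgb.LGBMRegressor(n_estimators=5).fit(df, y)
    assert model.feature_name_ == ['x1', 'x2']
    model.predict(df, validate_features=True)  # names and order match, so no error
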
From 774a7159c25f130753b6c951142682184992b350 Mon Sep 17 00:00:00 2001
From: José Morales
Date: Wed, 15 Jun 2022 14:53:45 -0500
Subject: [PATCH 18/19] move data conversion to Predictor.predict

---
 python-package/lightgbm/basic.py | 35 +++++++++++++++++---------------
 1 file changed, 19 insertions(+), 16 deletions(-)

diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py
index d67ae7404c08..91fe5ee4d8cd 100644
--- a/python-package/lightgbm/basic.py
+++ b/python-package/lightgbm/basic.py
@@ -751,7 +751,7 @@ def __getstate__(self):
         return this
 
     def predict(self, data, start_iteration=0, num_iteration=-1,
-                raw_score=False, pred_leaf=False, pred_contrib=False, data_has_header=False):
+                raw_score=False, pred_leaf=False, pred_contrib=False, data_has_header=False, validate_features=False):
         """Predict logic.
 
         Parameters
@@ -772,6 +772,9 @@ def predict(self, data, start_iteration=0, num_iteration=-1,
         data_has_header : bool, optional (default=False)
             Whether data has header.
             Used only for txt data.
+        validate_features : bool, optional (default=False)
+            If True, ensure that the features used to predict match the ones used to train.
+            Used only if data is pandas DataFrame.
 
         Returns
         -------
@@ -778,7 +781,21 @@ def predict(self, data, start_iteration=0, num_iteration=-1,
         result : numpy array, scipy.sparse or list of scipy.sparse
             Prediction result.
             Can be sparse or a list of sparse objects (each element represents predictions for one class) for feature contributions (when ``pred_contrib=True``).
         """
+        if isinstance(data, Dataset):
+            raise TypeError("Cannot use Dataset instance for prediction, please use raw data instead")
+        elif isinstance(data, pd_DataFrame) and validate_features:
+            data_names = [str(x) for x in data.columns]
+            ptr_names = (ctypes.c_char_p * len(data_names))()
+            ptr_names[:] = [x.encode('utf-8') for x in data_names]
+            _safe_call(
+                _LIB.LGBM_BoosterValidateFeatureNames(
+                    self.handle,
+                    ptr_names,
+                    ctypes.c_int(len(data_names)),
+                )
+            )
+        data = _data_from_pandas(data, None, None, self.pandas_categorical)[0]
         predict_type = C_API_PREDICT_NORMAL
         if raw_score:
             predict_type = C_API_PREDICT_RAW_SCORE
@@ -3534,20 +3551,6 @@ def predict(self, data, start_iteration=0, num_iteration=None,
             Prediction result.
             Can be sparse or a list of sparse objects (each element represents predictions for one class) for feature contributions (when ``pred_contrib=True``).
         """
-        if isinstance(data, Dataset):
-            raise TypeError("Cannot use Dataset instance for prediction, please use raw data instead")
-        elif isinstance(data, pd_DataFrame) and validate_features:
-            data_names = [str(x) for x in data.columns]
-            ptr_names = (ctypes.c_char_p * len(data_names))()
-            ptr_names[:] = [x.encode('utf-8') for x in data_names]
-            _safe_call(
-                _LIB.LGBM_BoosterValidateFeatureNames(
-                    self.handle,
-                    ptr_names,
-                    ctypes.c_int(len(data_names)),
-                )
-            )
-        data = _data_from_pandas(data, None, None, self.pandas_categorical)[0]
         predictor = self._to_predictor(deepcopy(kwargs))
         if num_iteration is None:
             if start_iteration <= 0:
@@ -3556,7 +3559,7 @@ def predict(self, data, start_iteration=0, num_iteration=None,
             num_iteration = -1
         return predictor.predict(data, start_iteration, num_iteration,
                                  raw_score, pred_leaf, pred_contrib,
-                                 data_has_header)
+                                 data_has_header, validate_features)

From 3ff99ff0792b877f26d0d6de25c56cdeb3de2e0d Mon Sep 17 00:00:00 2001
From: José Morales
Date: Mon, 20 Jun 2022 10:14:01 -0500
Subject: [PATCH 19/19] use Vector2Ptr

---
 src/c_api.cpp | 8 ++------
 1 file changed, 2 insertions(+), 6 deletions(-)

diff --git a/src/c_api.cpp b/src/c_api.cpp
index 3615faad8ffc..ebac89fe71e0 100644
--- a/src/c_api.cpp
+++ b/src/c_api.cpp
@@ -2139,12 +2139,8 @@ int LGBM_BoosterValidateFeatureNames(BoosterHandle handle,
   if (booster_num_features != data_num_features) {
     Log::Fatal("Model was trained on %d features, but got %d input features to predict.", booster_num_features, data_num_features);
   }
-  std::vector<std::vector<char>> tmp_names(booster_num_features);
-  std::vector<char*> booster_names(booster_num_features);
-  for (int i = 0; i < booster_num_features; ++i) {
-    tmp_names[i].resize(out_buffer_len);
-    booster_names[i] = tmp_names[i].data();
-  }
+  std::vector<std::vector<char>> tmp_names(booster_num_features, std::vector<char>(out_buffer_len));
+  std::vector<char*> booster_names = Vector2Ptr(&tmp_names);
   LGBM_BoosterGetFeatureNames(handle, data_num_features, &booster_num_features, out_buffer_len, &out_buffer_len, booster_names.data());
   for (int i = 0; i < booster_num_features; ++i) {
     if (strcmp(data_names[i], booster_names[i]) != 0) {
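
With the final refactor in place the check lives in a single C entry point, and both `Booster.predict` and the scikit-learn wrappers forward `validate_features` down to it. An end-to-end sketch of the merged behavior (illustrative data; the message text comes from the `Log::Fatal` calls above):

    import numpy as np
    import pandas as pd
    import lightgbm as lgb

    df = pd.DataFrame({f'x{i}': np.random.rand(100) for i in range(1, 4)})
    y = np.random.rand(100)
    bst = lgb.train({'num_leaves': 7, 'verbose': -1}, lgb.Dataset(df, y), num_boost_round=5)

    bst.predict(df, validate_features=True)  # passes: names and order match
    try:
        bst.predict(df.rename(columns={'x2': 'z'}), validate_features=True)
    except lgb.basic.LightGBMError as err:
        print(err)  # Expected 'x2' at position 1 but found 'z'
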