[python-package] check feature names in predict with dataframe (fixes #812) (#4909)

* check feature names and order in predict with dataframe

* slice df in predict to remove the target

* scramble features

* handle int column names

* only change column order when needed

* include validate_features param in booster and sklearn estimators

* document validate_features argument

* use all_close in preds checks and check for assertion error to compare different arrays

* perform remapping and checks in cpp

* remove extra logs

* fixes

* revert cpp

* proposal

* remove extra arg

* lint

* restore _data_from_pandas arguments

* Apply suggestions from code review

Co-authored-by: Nikita Titov <[email protected]>

* move data conversion to Predictor.predict

* use Vector2Ptr

Co-authored-by: Nikita Titov <[email protected]>
jmoralez and StrikerRUS authored Jun 27, 2022
1 parent 521fe8d commit bdb02e0
Showing 6 changed files with 127 additions and 19 deletions.
11 changes: 11 additions & 0 deletions include/LightGBM/c_api.h
@@ -678,6 +678,17 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterGetFeatureNames(BoosterHandle handle,
                                                  size_t* out_buffer_len,
                                                  char** out_strs);

/*!
 * \brief Check that the feature names of the data match the ones used to train the booster.
 * \param handle Handle of booster
 * \param data_names Array with the feature names in the data
 * \param data_num_features Number of features in the data
 * \return 0 when succeed, -1 when failure happens
 */
LIGHTGBM_C_EXPORT int LGBM_BoosterValidateFeatureNames(BoosterHandle handle,
                                                       const char** data_names,
                                                       int data_num_features);

/*!
 * \brief Get number of features.
 * \param handle Handle of booster
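For illustration, a minimal ctypes sketch of how a caller can invoke this new entry point; it mirrors the basic.py change below. The names _LIB (the loaded lib_lightgbm shared library) and booster_handle (a valid BoosterHandle) are assumptions here, not defined by this commit:

import ctypes

# Assumed names: _LIB is the loaded lib_lightgbm shared library and
# booster_handle is a valid BoosterHandle obtained from the C API.
data_names = ['x1', 'x2', 'x3', 'x4']
ptr_names = (ctypes.c_char_p * len(data_names))()
ptr_names[:] = [name.encode('utf-8') for name in data_names]
ret = _LIB.LGBM_BoosterValidateFeatureNames(
    booster_handle,
    ptr_names,
    ctypes.c_int(len(data_names)),
)
# 0 means the names matched; on failure the return code is nonzero and the
# message can be fetched with LGBM_GetLastError (the python-package wraps
# this pattern in _safe_call, which raises LightGBMError).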
24 changes: 21 additions & 3 deletions python-package/lightgbm/basic.py
@@ -757,7 +757,7 @@ def __getstate__(self):
        return this

    def predict(self, data, start_iteration=0, num_iteration=-1,
                raw_score=False, pred_leaf=False, pred_contrib=False, data_has_header=False):
                raw_score=False, pred_leaf=False, pred_contrib=False, data_has_header=False, validate_features=False):
        """Predict logic.

        Parameters
@@ -778,6 +778,9 @@ def predict(self, data, start_iteration=0, num_iteration=-1,
        data_has_header : bool, optional (default=False)
            Whether data has header.
            Used only for txt data.
        validate_features : bool, optional (default=False)
            If True, ensure that the features used to predict match the ones used to train.
            Used only if data is pandas DataFrame.

        Returns
        -------
@@ -787,6 +790,17 @@
"""
if isinstance(data, Dataset):
raise TypeError("Cannot use Dataset instance for prediction, please use raw data instead")
elif isinstance(data, pd_DataFrame) and validate_features:
data_names = [str(x) for x in data.columns]
ptr_names = (ctypes.c_char_p * len(data_names))()
ptr_names[:] = [x.encode('utf-8') for x in data_names]
_safe_call(
_LIB.LGBM_BoosterValidateFeatureNames(
self.handle,
ptr_names,
ctypes.c_int(len(data_names)),
)
)
data = _data_from_pandas(data, None, None, self.pandas_categorical)[0]
predict_type = C_API_PREDICT_NORMAL
if raw_score:
@@ -3501,7 +3515,8 @@ def dump_model(self, num_iteration=None, start_iteration=0, importance_type='spl

    def predict(self, data, start_iteration=0, num_iteration=None,
                raw_score=False, pred_leaf=False, pred_contrib=False,
                data_has_header=False, **kwargs):
                data_has_header=False, validate_features=False,
                **kwargs):
        """Make a prediction.

        Parameters
@@ -3535,6 +3550,9 @@ def predict(self, data, start_iteration=0, num_iteration=None,
        data_has_header : bool, optional (default=False)
            Whether the data has header.
            Used only if data is str.
        validate_features : bool, optional (default=False)
            If True, ensure that the features used to predict match the ones used to train.
            Used only if data is pandas DataFrame.
        **kwargs
            Other parameters for the prediction.
@@ -3552,7 +3570,7 @@
            num_iteration = -1
        return predictor.predict(data, start_iteration, num_iteration,
                                 raw_score, pred_leaf, pred_contrib,
                                 data_has_header)
                                 data_has_header, validate_features)

    def refit(
        self,
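To make the new flag concrete, here is a minimal end-to-end sketch at the Booster level; the data is synthetic and the failing call mirrors the test added in test_engine.py below. Note that the Python side stringifies column names before the check, so integer column names are handled as well:

import numpy as np
import pandas as pd
import lightgbm as lgb

rng = np.random.default_rng(0)
df = pd.DataFrame(rng.random((100, 4)), columns=['x1', 'x2', 'x3', 'x4'])
y = rng.random(100)
bst = lgb.train({'num_leaves': 15, 'verbose': -1}, lgb.Dataset(df, y), num_boost_round=10)

bst.predict(df, validate_features=True)  # column names match the training data

df2 = df.rename(columns={'x3': 'z'})
bst.predict(df2)  # default validate_features=False: no check is performed
try:
    bst.predict(df2, validate_features=True)
except lgb.basic.LightGBMError as err:
    print(err)  # Expected 'x3' at position 2 but found 'z'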
18 changes: 12 additions & 6 deletions python-package/lightgbm/sklearn.py
@@ -325,6 +325,9 @@ def __call__(self, preds: np.ndarray, dataset: Dataset) -> Tuple[str, float, boo
        Note that unlike the shap package, with ``pred_contrib`` we return a matrix with an extra
        column, where the last column is the expected value.
    validate_features : bool, optional (default=False)
        If True, ensure that the features used to predict match the ones used to train.
        Used only if data is pandas DataFrame.
    **kwargs
        Other parameters for the prediction.
@@ -820,7 +823,7 @@ def _get_meta_data(collection, name, i):
) + "\n\n" + _lgbmmodel_doc_custom_eval_note

def predict(self, X, raw_score=False, start_iteration=0, num_iteration=None,
pred_leaf=False, pred_contrib=False, **kwargs):
pred_leaf=False, pred_contrib=False, validate_features=False, **kwargs):
"""Docstring is set after definition, using a template."""
if not self.__sklearn_is_fitted__():
raise LGBMNotFittedError("Estimator not fitted, call fit before exploiting the model.")
@@ -853,7 +856,8 @@ def predict(self, X, raw_score=False, start_iteration=0, num_iteration=None,
        predict_params["num_threads"] = self._process_n_jobs(predict_params["num_threads"])

        return self._Booster.predict(X, raw_score=raw_score, start_iteration=start_iteration, num_iteration=num_iteration,
                                     pred_leaf=pred_leaf, pred_contrib=pred_contrib, **predict_params)
                                     pred_leaf=pred_leaf, pred_contrib=pred_contrib, validate_features=validate_features,
                                     **predict_params)

    predict.__doc__ = _lgbmmodel_doc_predict.format(
        description="Return the predicted value for each sample.",
@@ -1087,10 +1091,12 @@ def fit(
                     + _base_doc[_base_doc.find('eval_metric :'):])

    def predict(self, X, raw_score=False, start_iteration=0, num_iteration=None,
                pred_leaf=False, pred_contrib=False, **kwargs):
                pred_leaf=False, pred_contrib=False, validate_features=False,
                **kwargs):
        """Docstring is inherited from the LGBMModel."""
        result = self.predict_proba(X, raw_score, start_iteration, num_iteration,
                                    pred_leaf, pred_contrib, **kwargs)
                                    pred_leaf, pred_contrib, validate_features,
                                    **kwargs)
        if callable(self._objective) or raw_score or pred_leaf or pred_contrib:
            return result
        else:
@@ -1100,9 +1106,9 @@ def predict(self, X, raw_score=False, start_iteration=0, num_iteration=None,
    predict.__doc__ = LGBMModel.predict.__doc__

    def predict_proba(self, X, raw_score=False, start_iteration=0, num_iteration=None,
                      pred_leaf=False, pred_contrib=False, **kwargs):
                      pred_leaf=False, pred_contrib=False, validate_features=False, **kwargs):
        """Docstring is set after definition, using a template."""
        result = super().predict(X, raw_score, start_iteration, num_iteration, pred_leaf, pred_contrib, **kwargs)
        result = super().predict(X, raw_score, start_iteration, num_iteration, pred_leaf, pred_contrib, validate_features, **kwargs)
        if callable(self._objective) and not (raw_score or pred_leaf or pred_contrib):
            _log_warning("Cannot compute class probabilities or labels "
                         "due to the usage of customized objective function.\n"
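At the scikit-learn layer the flag is simply threaded through to Booster.predict, as the hunks above show. A short sketch with the classifier (synthetic data and feature names, for illustration only):

import numpy as np
import pandas as pd
import lightgbm as lgb

rng = np.random.default_rng(0)
X = pd.DataFrame(rng.random((100, 4)), columns=['x1', 'x2', 'x3', 'x4'])
y = rng.integers(0, 2, size=100)
clf = lgb.LGBMClassifier(n_estimators=5, num_leaves=7).fit(X, y)

# predict forwards validate_features to predict_proba, which forwards it
# through LGBMModel.predict down to Booster.predict.
clf.predict(X, validate_features=True)
clf.predict_proba(X, validate_features=True)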
21 changes: 21 additions & 0 deletions src/c_api.cpp
@@ -2129,6 +2129,27 @@ int LGBM_BoosterPredictForCSC(BoosterHandle handle,
  API_END();
}

int LGBM_BoosterValidateFeatureNames(BoosterHandle handle,
                                     const char** data_names,
                                     int data_num_features) {
  API_BEGIN();
  int booster_num_features;
  size_t out_buffer_len;
  LGBM_BoosterGetFeatureNames(handle, 0, &booster_num_features, 0, &out_buffer_len, nullptr);
  if (booster_num_features != data_num_features) {
    Log::Fatal("Model was trained on %d features, but got %d input features to predict.", booster_num_features, data_num_features);
  }
  std::vector<std::vector<char>> tmp_names(booster_num_features, std::vector<char>(out_buffer_len));
  std::vector<char*> booster_names = Vector2Ptr(&tmp_names);
  LGBM_BoosterGetFeatureNames(handle, data_num_features, &booster_num_features, out_buffer_len, &out_buffer_len, booster_names.data());
  for (int i = 0; i < booster_num_features; ++i) {
    if (strcmp(data_names[i], booster_names[i]) != 0) {
      Log::Fatal("Expected '%s' at position %d but found '%s'", booster_names[i], i, data_names[i]);
    }
  }
  API_END();
}

int LGBM_BoosterPredictForMat(BoosterHandle handle,
                              const void* data,
                              int data_type,
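The implementation above uses the usual C API two-call pattern: a first LGBM_BoosterGetFeatureNames call with empty buffers to learn the feature count and the required per-name buffer length, then a second call into buffers sized from those values, with Vector2Ptr exposing the vector-of-vectors as a char** array. A rough pure-Python equivalent of the check, written against the public Booster API purely for illustration (validate_feature_names is a hypothetical helper, not part of this commit):

def validate_feature_names(booster, data_names):
    """Sketch of the comparison LGBM_BoosterValidateFeatureNames performs."""
    booster_names = booster.feature_name()
    if len(booster_names) != len(data_names):
        raise ValueError(
            f"Model was trained on {len(booster_names)} features, but got "
            f"{len(data_names)} input features to predict."
        )
    # Names must match position by position, so reordered columns also fail.
    for i, (expected, found) in enumerate(zip(booster_names, data_names)):
        if expected != found:
            raise ValueError(f"Expected '{expected}' at position {i} but found '{found}'")

Since Log::Fatal raises inside the library, a mismatch surfaces through the standard C API error path (nonzero return plus LGBM_GetLastError) rather than returning from the loop.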
19 changes: 19 additions & 0 deletions tests/python_package_test/test_engine.py
@@ -18,6 +18,7 @@
from sklearn.model_selection import GroupKFold, TimeSeriesSplit, train_test_split

import lightgbm as lgb
from lightgbm.compat import PANDAS_INSTALLED, pd_DataFrame

from .utils import (dummy_obj, load_boston, load_breast_cancer, load_digits, load_iris, logistic_sigmoid,
                    make_synthetic_regression, mse_obj, sklearn_multiclass_custom_objective, softmax)
@@ -3623,3 +3624,21 @@ def test_cegb_split_buffer_clean():
    predicts = model.predict(test_data)
    rmse = np.sqrt(mean_squared_error(test_y, predicts))
    assert rmse < 10.0


@pytest.mark.skipif(not PANDAS_INSTALLED, reason='pandas is not installed')
def test_validate_features():
    X, y = make_synthetic_regression()
    features = ['x1', 'x2', 'x3', 'x4']
    df = pd_DataFrame(X, columns=features)
    ds = lgb.Dataset(df, y)
    bst = lgb.train({'num_leaves': 15, 'verbose': -1}, ds, num_boost_round=10)
    assert bst.feature_name() == features

    # try to predict with a different feature
    df2 = df.rename(columns={'x3': 'z'})
    with pytest.raises(lgb.basic.LightGBMError, match="Expected 'x3' at position 2 but found 'z'"):
        bst.predict(df2, validate_features=True)

    # check that disabling the check doesn't raise the error
    bst.predict(df2, validate_features=False)
53 changes: 43 additions & 10 deletions tests/python_package_test/test_sklearn.py
@@ -18,11 +18,30 @@
from sklearn.utils.validation import check_is_fitted

import lightgbm as lgb
from lightgbm.compat import PANDAS_INSTALLED, pd_DataFrame

from .utils import (load_boston, load_breast_cancer, load_digits, load_iris, load_linnerud, make_ranking,
                    make_synthetic_regression, sklearn_multiclass_custom_objective, softmax)

decreasing_generator = itertools.count(0, -1)
task_to_model_factory = {
    'ranking': lgb.LGBMRanker,
    'classification': lgb.LGBMClassifier,
    'regression': lgb.LGBMRegressor,
}


def _create_data(task):
    if task == 'ranking':
        X, y, g = make_ranking(n_features=4)
        g = np.bincount(g)
    elif task == 'classification':
        X, y = load_iris(return_X_y=True)
        g = None
    elif task == 'regression':
        X, y = make_synthetic_regression()
        g = None
    return X, y, g


class UnpicklableCallback:
@@ -1244,23 +1263,15 @@ def test_sklearn_integration(estimator, check):
@pytest.mark.parametrize('task', ['classification', 'ranking', 'regression'])
def test_training_succeeds_when_data_is_dataframe_and_label_is_column_array(task):
    pd = pytest.importorskip("pandas")
    if task == 'ranking':
        X, y, g = make_ranking()
        g = np.bincount(g)
        model_factory = lgb.LGBMRanker
    elif task == 'classification':
        X, y = load_iris(return_X_y=True)
        model_factory = lgb.LGBMClassifier
    elif task == 'regression':
        X, y = make_synthetic_regression()
        model_factory = lgb.LGBMRegressor
    X, y, g = _create_data(task)
    X = pd.DataFrame(X)
    y_col_array = y.reshape(-1, 1)
    params = {
        'n_estimators': 1,
        'num_leaves': 3,
        'random_state': 0
    }
    model_factory = task_to_model_factory[task]
    with pytest.warns(UserWarning, match='column-vector'):
        if task == 'ranking':
            model_1d = model_factory(**params).fit(X, y, group=g)
@@ -1315,3 +1326,25 @@ def test_default_n_jobs(tmp_path):
    with open(tmp_path / "model.txt", "r") as f:
        model_txt = f.read()
    assert bool(re.search(rf"\[num_threads: {n_cores}\]", model_txt))


@pytest.mark.skipif(not PANDAS_INSTALLED, reason='pandas is not installed')
@pytest.mark.parametrize('task', ['classification', 'ranking', 'regression'])
def test_validate_features(task):
    X, y, g = _create_data(task)
    features = ['x1', 'x2', 'x3', 'x4']
    df = pd_DataFrame(X, columns=features)
    model = task_to_model_factory[task](n_estimators=10, num_leaves=15, verbose=-1)
    if task == 'ranking':
        model.fit(df, y, group=g)
    else:
        model.fit(df, y)
    assert model.feature_name_ == features

    # try to predict with a different feature
    df2 = df.rename(columns={'x2': 'z'})
    with pytest.raises(lgb.basic.LightGBMError, match="Expected 'x2' at position 1 but found 'z'"):
        model.predict(df2, validate_features=True)

    # check that disabling the check doesn't raise the error
    model.predict(df2, validate_features=False)
