[python-package] check feature names in predict with dataframe (fixes #812) #4909

Merged · 25 commits · Jun 27, 2022
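This PR adds a validate_features argument (default False) to the Booster and scikit-learn predict paths. When predicting from a pandas DataFrame with validate_features=True, the column names are checked against the feature names the booster was trained with, and a LightGBMError is raised on any mismatch. A minimal sketch of the new behavior (the synthetic data and column names are illustrative, mirroring the tests added below):

import lightgbm as lgb
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
df = pd.DataFrame(rng.random((100, 4)), columns=['x1', 'x2', 'x3', 'x4'])
bst = lgb.train({'num_leaves': 15, 'verbose': -1}, lgb.Dataset(df, rng.random(100)), num_boost_round=10)

# Renaming a column now raises lgb.basic.LightGBMError when validation is on:
# "Expected 'x3' at position 2 but found 'z'"
bst.predict(df.rename(columns={'x3': 'z'}), validate_features=True)

# The default (validate_features=False) keeps the old, unchecked behavior.
bst.predict(df.rename(columns={'x3': 'z'}))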
11 changes: 11 additions & 0 deletions include/LightGBM/c_api.h
@@ -678,6 +678,17 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterGetFeatureNames(BoosterHandle handle,
size_t* out_buffer_len,
char** out_strs);

/*!
* \brief Check that the feature names of the data match the ones used to train the booster.
* \param handle Handle of booster
* \param data_names Array with the feature names in the data
* \param data_num_features Number of features in the data
* \return 0 when succeed, -1 when failure happens
*/
LIGHTGBM_C_EXPORT int LGBM_BoosterValidateFeatureNames(BoosterHandle handle,
const char** data_names,
int data_num_features);

/*!
* \brief Get number of features.
* \param handle Handle of booster
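For reference, a minimal ctypes sketch of calling the new entry point directly against the lib_lightgbm shared library (the library filename and the origin of the booster handle are assumptions; the python package's own wrapper is shown in basic.py below). LGBM_GetLastError is the existing C API function for retrieving the message behind a non-zero return code:

import ctypes

lib = ctypes.CDLL('lib_lightgbm.so')  # assumption: adjust the filename/path for your platform
lib.LGBM_GetLastError.restype = ctypes.c_char_p

def validate_feature_names(booster_handle, names):
    # Build a char** of UTF-8 encoded names and pass it to the new entry point.
    ptr_names = (ctypes.c_char_p * len(names))(*[n.encode('utf-8') for n in names])
    ret = lib.LGBM_BoosterValidateFeatureNames(booster_handle, ptr_names, ctypes.c_int(len(names)))
    if ret != 0:
        raise RuntimeError(lib.LGBM_GetLastError().decode('utf-8'))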
24 changes: 21 additions & 3 deletions python-package/lightgbm/basic.py
@@ -751,7 +751,7 @@ def __getstate__(self):
return this

def predict(self, data, start_iteration=0, num_iteration=-1,
raw_score=False, pred_leaf=False, pred_contrib=False, data_has_header=False):
raw_score=False, pred_leaf=False, pred_contrib=False, data_has_header=False, validate_features=False):
"""Predict logic.

Parameters
@@ -772,6 +772,9 @@ def predict(self, data, start_iteration=0, num_iteration=-1,
data_has_header : bool, optional (default=False)
Whether data has header.
Used only for txt data.
validate_features : bool, optional (default=False)
If True, ensure that the features used to predict match the ones used to train.
Used only if data is pandas DataFrame.

Returns
-------
@@ -781,6 +784,17 @@ def predict(self, data, start_iteration=0, num_iteration=-1,
"""
if isinstance(data, Dataset):
raise TypeError("Cannot use Dataset instance for prediction, please use raw data instead")
elif isinstance(data, pd_DataFrame) and validate_features:
data_names = [str(x) for x in data.columns]
ptr_names = (ctypes.c_char_p * len(data_names))()
ptr_names[:] = [x.encode('utf-8') for x in data_names]
_safe_call(
_LIB.LGBM_BoosterValidateFeatureNames(
self.handle,
ptr_names,
ctypes.c_int(len(data_names)),
)
)
data = _data_from_pandas(data, None, None, self.pandas_categorical)[0]
Collaborator:
Why was this moved out of the Predictor code? Now Predictor cannot accept a pandas DataFrame, which means, for example, that the refit() method cannot accept DataFrames anymore:

leaf_preds = predictor.predict(data, -1, pred_leaf=True)

and DataFrames cannot be used together with the init_model argument when constructing a Dataset:

init_score = predictor.predict(data,
                               raw_score=True,
                               data_has_header=data_has_header)

Collaborator:

I agree with this comment, and will just add that it would be useful to have unit tests (in a separate PR) for them, so that such regressions can be caught automatically in the future.

Collaborator (author):

Addressed in 774a715

Collaborator:

Thanks!

Should we add the validate_features argument to those methods as well?

Collaborator (author):

I can work on adding it to refit as well.

predict_type = C_API_PREDICT_NORMAL
if raw_score:
@@ -3489,7 +3503,8 @@ def dump_model(self, num_iteration=None, start_iteration=0, importance_type='spl

def predict(self, data, start_iteration=0, num_iteration=None,
raw_score=False, pred_leaf=False, pred_contrib=False,
data_has_header=False, **kwargs):
data_has_header=False, validate_features=False,
**kwargs):
"""Make a prediction.

Parameters
@@ -3523,6 +3538,9 @@ def predict(self, data, start_iteration=0, num_iteration=None,
data_has_header : bool, optional (default=False)
Whether the data has header.
Used only if data is str.
validate_features : bool, optional (default=False)
If True, ensure that the features used to predict match the ones used to train.
Used only if data is pandas DataFrame.
**kwargs
Other parameters for the prediction.

@@ -3540,7 +3558,7 @@ def predict(self, data, start_iteration=0, num_iteration=None,
num_iteration = -1
return predictor.predict(data, start_iteration, num_iteration,
raw_score, pred_leaf, pred_contrib,
data_has_header)
data_has_header, validate_features)

def refit(
self,
18 changes: 12 additions & 6 deletions python-package/lightgbm/sklearn.py
@@ -325,6 +325,9 @@ def __call__(self, preds: np.ndarray, dataset: Dataset) -> Tuple[str, float, boo
Note that unlike the shap package, with ``pred_contrib`` we return a matrix with an extra
column, where the last column is the expected value.

validate_features : bool, optional (default=False)
If True, ensure that the features used to predict match the ones used to train.
Used only if data is pandas DataFrame.
**kwargs
Other parameters for the prediction.

@@ -784,7 +787,7 @@ def _get_meta_data(collection, name, i):
) + "\n\n" + _lgbmmodel_doc_custom_eval_note

def predict(self, X, raw_score=False, start_iteration=0, num_iteration=None,
pred_leaf=False, pred_contrib=False, **kwargs):
pred_leaf=False, pred_contrib=False, validate_features=False, **kwargs):
"""Docstring is set after definition, using a template."""
if not self.__sklearn_is_fitted__():
raise LGBMNotFittedError("Estimator not fitted, call fit before exploiting the model.")
@@ -811,7 +814,8 @@ def predict(self, X, raw_score=False, start_iteration=0, num_iteration=None,
predict_params.pop(alias, None)
predict_params.update(kwargs)
return self._Booster.predict(X, raw_score=raw_score, start_iteration=start_iteration, num_iteration=num_iteration,
pred_leaf=pred_leaf, pred_contrib=pred_contrib, **predict_params)
pred_leaf=pred_leaf, pred_contrib=pred_contrib, validate_features=validate_features,
**predict_params)

predict.__doc__ = _lgbmmodel_doc_predict.format(
description="Return the predicted value for each sample.",
@@ -1045,10 +1049,12 @@ def fit(
+ _base_doc[_base_doc.find('eval_metric :'):])

def predict(self, X, raw_score=False, start_iteration=0, num_iteration=None,
pred_leaf=False, pred_contrib=False, **kwargs):
pred_leaf=False, pred_contrib=False, validate_features=False,
**kwargs):
"""Docstring is inherited from the LGBMModel."""
result = self.predict_proba(X, raw_score, start_iteration, num_iteration,
pred_leaf, pred_contrib, **kwargs)
pred_leaf, pred_contrib, validate_features,
**kwargs)
if callable(self._objective) or raw_score or pred_leaf or pred_contrib:
return result
else:
@@ -1058,9 +1064,9 @@ def predict(self, X, raw_score=False, start_iteration=0, num_iteration=None,
predict.__doc__ = LGBMModel.predict.__doc__

def predict_proba(self, X, raw_score=False, start_iteration=0, num_iteration=None,
pred_leaf=False, pred_contrib=False, **kwargs):
pred_leaf=False, pred_contrib=False, validate_features=False, **kwargs):
"""Docstring is set after definition, using a template."""
result = super().predict(X, raw_score, start_iteration, num_iteration, pred_leaf, pred_contrib, **kwargs)
result = super().predict(X, raw_score, start_iteration, num_iteration, pred_leaf, pred_contrib, validate_features, **kwargs)
if callable(self._objective) and not (raw_score or pred_leaf or pred_contrib):
_log_warning("Cannot compute class probabilities or labels "
"due to the usage of customized objective function.\n"
25 changes: 25 additions & 0 deletions src/c_api.cpp
@@ -2129,6 +2129,31 @@ int LGBM_BoosterPredictForCSC(BoosterHandle handle,
API_END();
}

int LGBM_BoosterValidateFeatureNames(BoosterHandle handle,
const char** data_names,
int data_num_features) {
API_BEGIN();
int booster_num_features;
size_t out_buffer_len;
LGBM_BoosterGetFeatureNames(handle, 0, &booster_num_features, 0, &out_buffer_len, nullptr);
if (booster_num_features != data_num_features) {
Log::Fatal("Model was trained on %d features, but got %d input features to predict.", booster_num_features, data_num_features);
}
std::vector<std::vector<char>> tmp_names(booster_num_features);
std::vector<char*> booster_names(booster_num_features);
for (int i = 0; i < booster_num_features; ++i) {
tmp_names[i].resize(out_buffer_len);
booster_names[i] = tmp_names[i].data();
}
LGBM_BoosterGetFeatureNames(handle, data_num_features, &booster_num_features, out_buffer_len, &out_buffer_len, booster_names.data());
for (int i = 0; i < booster_num_features; ++i) {
if (strcmp(data_names[i], booster_names[i]) != 0) {
Log::Fatal("Expected '%s' at position %d but found '%s'", booster_names[i], i, data_names[i]);
}
}
API_END();
}

int LGBM_BoosterPredictForMat(BoosterHandle handle,
const void* data,
int data_type,
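The implementation relies on the two-call convention of the existing LGBM_BoosterGetFeatureNames: a first call with a null output array only reports the number of features and the required per-name buffer size, and a second call with allocated buffers fills in the actual names. A hedged ctypes sketch of that convention (loosely following what the python package does in Booster.feature_name(); error checking omitted):

import ctypes

def booster_feature_names(lib, handle):
    num_features = ctypes.c_int(0)
    buffer_len = ctypes.c_size_t(0)
    # First call: null output array, so only the two size outputs are written.
    lib.LGBM_BoosterGetFeatureNames(
        handle, ctypes.c_int(0), ctypes.byref(num_features),
        ctypes.c_size_t(0), ctypes.byref(buffer_len), None)
    # Second call: one buffer of buffer_len bytes per feature name.
    buffers = [ctypes.create_string_buffer(buffer_len.value) for _ in range(num_features.value)]
    ptr_names = (ctypes.c_char_p * num_features.value)(*map(ctypes.addressof, buffers))
    lib.LGBM_BoosterGetFeatureNames(
        handle, num_features, ctypes.byref(num_features),
        buffer_len, ctypes.byref(buffer_len), ptr_names)
    return [b.value.decode('utf-8') for b in buffers]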
19 changes: 19 additions & 0 deletions tests/python_package_test/test_engine.py
@@ -18,6 +18,7 @@
from sklearn.model_selection import GroupKFold, TimeSeriesSplit, train_test_split

import lightgbm as lgb
from lightgbm.compat import PANDAS_INSTALLED, pd_DataFrame

from .utils import (dummy_obj, load_boston, load_breast_cancer, load_digits, load_iris, logistic_sigmoid,
make_synthetic_regression, mse_obj, sklearn_multiclass_custom_objective, softmax)
@@ -3623,3 +3624,21 @@ def test_cegb_split_buffer_clean():
predicts = model.predict(test_data)
rmse = np.sqrt(mean_squared_error(test_y, predicts))
assert rmse < 10.0


@pytest.mark.skipif(not PANDAS_INSTALLED, reason='pandas is not installed')
def test_validate_features():
X, y = make_synthetic_regression()
features = ['x1', 'x2', 'x3', 'x4']
df = pd_DataFrame(X, columns=features)
ds = lgb.Dataset(df, y)
bst = lgb.train({'num_leaves': 15, 'verbose': -1}, ds, num_boost_round=10)
assert bst.feature_name() == features

# try to predict with a different feature
df2 = df.rename(columns={'x3': 'z'})
with pytest.raises(lgb.basic.LightGBMError, match="Expected 'x3' at position 2 but found 'z'"):
bst.predict(df2, validate_features=True)

# check that disabling the check doesn't raise the error
bst.predict(df2, validate_features=False)
53 changes: 43 additions & 10 deletions tests/python_package_test/test_sklearn.py
@@ -17,11 +17,30 @@
from sklearn.utils.validation import check_is_fitted

import lightgbm as lgb
from lightgbm.compat import PANDAS_INSTALLED, pd_DataFrame

from .utils import (load_boston, load_breast_cancer, load_digits, load_iris, load_linnerud, make_ranking,
make_synthetic_regression, sklearn_multiclass_custom_objective, softmax)

decreasing_generator = itertools.count(0, -1)
task_to_model_factory = {
'ranking': lgb.LGBMRanker,
'classification': lgb.LGBMClassifier,
'regression': lgb.LGBMRegressor,
}


def _create_data(task):
if task == 'ranking':
X, y, g = make_ranking(n_features=4)
g = np.bincount(g)
elif task == 'classification':
X, y = load_iris(return_X_y=True)
g = None
elif task == 'regression':
X, y = make_synthetic_regression()
g = None
return X, y, g


class UnpicklableCallback:
@@ -1243,23 +1262,15 @@ def test_sklearn_integration(estimator, check):
@pytest.mark.parametrize('task', ['classification', 'ranking', 'regression'])
def test_training_succeeds_when_data_is_dataframe_and_label_is_column_array(task):
pd = pytest.importorskip("pandas")
if task == 'ranking':
X, y, g = make_ranking()
g = np.bincount(g)
model_factory = lgb.LGBMRanker
elif task == 'classification':
X, y = load_iris(return_X_y=True)
model_factory = lgb.LGBMClassifier
elif task == 'regression':
X, y = make_synthetic_regression()
model_factory = lgb.LGBMRegressor
X, y, g = _create_data(task)
X = pd.DataFrame(X)
y_col_array = y.reshape(-1, 1)
params = {
'n_estimators': 1,
'num_leaves': 3,
'random_state': 0
}
model_factory = task_to_model_factory[task]
with pytest.warns(UserWarning, match='column-vector'):
if task == 'ranking':
model_1d = model_factory(**params).fit(X, y, group=g)
@@ -1288,3 +1299,25 @@ def test_multiclass_custom_objective():
np.testing.assert_allclose(builtin_obj_preds, custom_obj_preds, rtol=0.01)
assert not callable(builtin_obj_model.objective_)
assert callable(custom_obj_model.objective_)


@pytest.mark.skipif(not PANDAS_INSTALLED, reason='pandas is not installed')
@pytest.mark.parametrize('task', ['classification', 'ranking', 'regression'])
def test_validate_features(task):
X, y, g = _create_data(task)
features = ['x1', 'x2', 'x3', 'x4']
df = pd_DataFrame(X, columns=features)
model = task_to_model_factory[task](n_estimators=10, num_leaves=15, verbose=-1)
if task == 'ranking':
model.fit(df, y, group=g)
else:
model.fit(df, y)
assert model.feature_name_ == features

# try to predict with a different feature
df2 = df.rename(columns={'x2': 'z'})
with pytest.raises(lgb.basic.LightGBMError, match="Expected 'x2' at position 1 but found 'z'"):
model.predict(df2, validate_features=True)

# check that disabling the check doesn't raise the error
model.predict(df2, validate_features=False)