
[python-package] check feature names in predict with dataframe (fixes #812) #4909

Merged · 25 commits · Jun 27, 2022 · Changes from 9 commits
5 changes: 3 additions & 2 deletions examples/python-guide/advanced_example.py
@@ -81,7 +81,8 @@
 # load model to predict
 bst = lgb.Booster(model_file='model.txt')
 # can only predict with the best iteration (or the saving iteration)
-y_pred = bst.predict(X_test)
+# disable validating feature names since we know they're correct
+y_pred = bst.predict(X_test, validate_features=False)
 # eval with loaded model
 rmse_loaded_model = mean_squared_error(y_test, y_pred) ** 0.5
 print(f"The RMSE of loaded model's prediction is: {rmse_loaded_model}")
@@ -94,7 +95,7 @@
 with open('model.pkl', 'rb') as fin:
     pkl_bst = pickle.load(fin)
 # can predict with any iteration when loaded in pickle way
-y_pred = pkl_bst.predict(X_test, num_iteration=7)
+y_pred = pkl_bst.predict(X_test, num_iteration=7, validate_features=False)
 # eval with loaded model
 rmse_pickled_model = mean_squared_error(y_test, y_pred) ** 0.5
 print(f"The RMSE of pickled model's prediction is: {rmse_pickled_model}")
25 changes: 20 additions & 5 deletions python-package/lightgbm/basic.py
@@ -512,12 +512,23 @@ def _get_bad_pandas_dtypes(dtypes):
     return bad_indices


-def _data_from_pandas(data, feature_name, categorical_feature, pandas_categorical):
+def _data_from_pandas(data, feature_name, categorical_feature, pandas_categorical, validate_features=False):
     if isinstance(data, pd_DataFrame):
         if len(data.shape) != 2 or data.shape[0] < 1:
             raise ValueError('Input data must be 2 dimensional and non empty.')
         if feature_name == 'auto' or feature_name is None:
             data = data.rename(columns=str)
+        elif isinstance(feature_name, list) and validate_features:
+            df_features = [str(x) for x in data.columns]
+            missing_features = set(feature_name) - set(df_features)
+            if missing_features:
+                raise ValueError(
+                    f"The following features are missing: {missing_features}.\n"
+                    "If you're sure the features are correct you can disable this check by setting validate_features=False"
+                )
+            sort_idxs = [df_features.index(feature) for feature in feature_name]
+            if not all(x == i for i, x in enumerate(sort_idxs)):
+                data = data.iloc[:, sort_idxs]  # ensure column order
         cat_cols = [col for col, dtype in zip(data.columns, data.dtypes) if isinstance(dtype, pd_CategoricalDtype)]
         cat_cols_not_ordered = [col for col in cat_cols if not data[col].cat.ordered]
         if pandas_categorical is None:  # train dataset
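The reordering branch above can be traced in isolation with plain pandas (a sketch of the same index arithmetic, not LightGBM internals; names and values are illustrative):

    import pandas as pd

    feature_name = ['x1', 'x2', 'x3']  # column order seen at training time
    df = pd.DataFrame({'x3': [30], 'x1': [10], 'x2': [20]})  # prediction-time order

    df_features = [str(x) for x in df.columns]                # ['x3', 'x1', 'x2']
    sort_idxs = [df_features.index(f) for f in feature_name]  # [1, 2, 0]
    df = df.iloc[:, sort_idxs]  # columns are now ['x1', 'x2', 'x3'] again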
@@ -767,9 +778,6 @@ def predict(self, data, start_iteration=0, num_iteration=-1,
             Prediction result.
             Can be sparse or a list of sparse objects (each element represents predictions for one class) for feature contributions (when ``pred_contrib=True``).
         """
-        if isinstance(data, Dataset):
-            raise TypeError("Cannot use Dataset instance for prediction, please use raw data instead")
-        data = _data_from_pandas(data, None, None, self.pandas_categorical)[0]
Review thread:

Collaborator: Why was this moved from Predictor code? Now Predictor cannot accept a pandas DataFrame, which means, for example, that the refit() method cannot accept DataFrames anymore:

    leaf_preds = predictor.predict(data, -1, pred_leaf=True)

and DataFrames cannot be used together with the init_model argument when constructing a Dataset:

    init_score = predictor.predict(data,
                                   raw_score=True,
                                   data_has_header=data_has_header)

Collaborator: I agree with this comment, and will just add that it would be useful to have unit tests (in a separate PR) for these code paths, so such regressions could be caught automatically in the future.

Author: Addressed in 774a715

Collaborator: Thanks! Should we add the validate_features argument to those methods as well?

Author: I can work on adding it to refit as well.

         predict_type = C_API_PREDICT_NORMAL
         if raw_score:
             predict_type = C_API_PREDICT_RAW_SCORE
@@ -3456,7 +3464,8 @@ def dump_model(self, num_iteration=None, start_iteration=0, importance_type='spl

     def predict(self, data, start_iteration=0, num_iteration=None,
                 raw_score=False, pred_leaf=False, pred_contrib=False,
-                data_has_header=False, is_reshape=True, **kwargs):
+                data_has_header=False, is_reshape=True, validate_features=True,
+                **kwargs):
"""Make a prediction.

Parameters
@@ -3492,6 +3501,9 @@ def predict(self, data, start_iteration=0, num_iteration=None,
             Used only if data is str.
         is_reshape : bool, optional (default=True)
             If True, result is reshaped to [nrow, ncol].
+        validate_features : bool, optional (default=True)
+            If True, ensure that the features used to predict match the ones used to train.
+            Used only if data is pandas DataFrame.
         **kwargs
             Other parameters for the prediction.
@@ -3501,6 +3513,9 @@
             Prediction result.
             Can be sparse or a list of sparse objects (each element represents predictions for one class) for feature contributions (when ``pred_contrib=True``).
         """
+        if isinstance(data, Dataset):
+            raise TypeError("Cannot use Dataset instance for prediction, please use raw data instead")
+        data = _data_from_pandas(data, self.feature_name(), None, self.pandas_categorical, validate_features=validate_features)[0]
         predictor = self._to_predictor(deepcopy(kwargs))
         if num_iteration is None:
             if start_iteration <= 0:
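Because the Dataset guard and the pandas conversion now live in Booster.predict, both behaviors are observable directly on a Booster (a self-contained sketch with made-up data):

    import numpy as np
    import pandas as pd
    import lightgbm as lgb

    df = pd.DataFrame(np.random.rand(50, 2), columns=['a', 'b'])
    y = np.random.rand(50)
    ds = lgb.Dataset(df, y)
    bst = lgb.train({'verbose': -1}, ds, num_boost_round=2)

    bst.predict(df)  # fine: names match bst.feature_name()
    bst.predict(ds)  # TypeError: Cannot use Dataset instance for prediction, please use raw data instead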
18 changes: 12 additions & 6 deletions python-package/lightgbm/sklearn.py
@@ -340,6 +340,9 @@ def __call__(self, preds, dataset):
             Note that unlike the shap package, with ``pred_contrib`` we return a matrix with an extra
             column, where the last column is the expected value.

+        validate_features : bool, optional (default=True)
+            If True, ensure that the features used to predict match the ones used to train.
+            Used only if data is pandas DataFrame.
         **kwargs
             Other parameters for the prediction.
@@ -802,7 +805,7 @@ def _get_meta_data(collection, name, i):
) + "\n\n" + _lgbmmodel_doc_custom_eval_note

def predict(self, X, raw_score=False, start_iteration=0, num_iteration=None,
pred_leaf=False, pred_contrib=False, **kwargs):
pred_leaf=False, pred_contrib=False, validate_features=True, **kwargs):
"""Docstring is set after definition, using a template."""
if not self.__sklearn_is_fitted__():
raise LGBMNotFittedError("Estimator not fitted, call fit before exploiting the model.")
@@ -829,7 +832,8 @@ def predict(self, X, raw_score=False, start_iteration=0, num_iteration=None,
         predict_params.pop(alias, None)
         predict_params.update(kwargs)
         return self._Booster.predict(X, raw_score=raw_score, start_iteration=start_iteration, num_iteration=num_iteration,
-                                     pred_leaf=pred_leaf, pred_contrib=pred_contrib, **predict_params)
+                                     pred_leaf=pred_leaf, pred_contrib=pred_contrib, validate_features=validate_features,
+                                     **predict_params)

     predict.__doc__ = _lgbmmodel_doc_predict.format(
         description="Return the predicted value for each sample.",
@@ -1063,10 +1067,12 @@ def fit(
             + _base_doc[_base_doc.find('eval_metric :'):])

     def predict(self, X, raw_score=False, start_iteration=0, num_iteration=None,
-                pred_leaf=False, pred_contrib=False, **kwargs):
+                pred_leaf=False, pred_contrib=False, validate_features=True,
+                **kwargs):
         """Docstring is inherited from the LGBMModel."""
         result = self.predict_proba(X, raw_score, start_iteration, num_iteration,
-                                    pred_leaf, pred_contrib, **kwargs)
+                                    pred_leaf, pred_contrib, validate_features,
+                                    **kwargs)
         if callable(self._objective) or raw_score or pred_leaf or pred_contrib:
             return result
         else:
@@ -1076,9 +1082,9 @@ def predict(self, X, raw_score=False, start_iteration=0, num_iteration=None,
     predict.__doc__ = LGBMModel.predict.__doc__

     def predict_proba(self, X, raw_score=False, start_iteration=0, num_iteration=None,
-                      pred_leaf=False, pred_contrib=False, **kwargs):
+                      pred_leaf=False, pred_contrib=False, validate_features=True, **kwargs):
         """Docstring is set after definition, using a template."""
-        result = super().predict(X, raw_score, start_iteration, num_iteration, pred_leaf, pred_contrib, **kwargs)
+        result = super().predict(X, raw_score, start_iteration, num_iteration, pred_leaf, pred_contrib, validate_features, **kwargs)
         if callable(self._objective) and not (raw_score or pred_leaf or pred_contrib):
             _log_warning("Cannot compute class probabilities or labels "
                          "due to the usage of customized objective function.\n"
30 changes: 29 additions & 1 deletion tests/python_package_test/test_basic.py
@@ -13,7 +13,7 @@
 import lightgbm as lgb
 from lightgbm.compat import PANDAS_INSTALLED, pd_DataFrame, pd_Series

-from .utils import load_breast_cancer
+from .utils import load_breast_cancer, make_synthetic_regression


 def test_basic(tmp_path):
@@ -609,3 +609,31 @@ def test_custom_objective_safety():
     good_bst_multi.update(fobj=_good_gradients)
     with pytest.raises(ValueError, match=re.escape(f"number of models per one iteration ({nclass})")):
         bad_bst_multi.update(fobj=_bad_gradients)
+
+
+@pytest.mark.skipif(not PANDAS_INSTALLED, reason='pandas is not installed')
+def test_validate_features():
+    X, y = make_synthetic_regression()
+    features = ['x1', 'x2', 'x3', 'x4']
+    df = pd_DataFrame(X, columns=features)
+    ds = lgb.Dataset(df, y)
+    bst = lgb.train({'num_leaves': 15, 'verbose': -1}, ds, num_boost_round=10)
+    assert bst.feature_name() == features
+
+    # try to predict with a different feature
+    df2 = df.rename(columns={'x1': 'z'})
+    with pytest.raises(ValueError, match="The following features are missing: {'x1'}"):
+        bst.predict(df2)
+
+    # check that disabling the check doesn't raise the error
+    bst.predict(df2, validate_features=False)
+
+    # predict with the features out of order
+    preds_sorted_features = bst.predict(df[features])
+    scrambled_features = ['x3', 'x1', 'x4', 'x2']
+    preds_scrambled_features = bst.predict(df[scrambled_features])
+    np.testing.assert_allclose(preds_sorted_features, preds_scrambled_features)
+
+    # check that disabling the check doesn't raise an error and produces incorrect predictions
+    preds_scrambled_features_no_check = bst.predict(df[scrambled_features], validate_features=False)
+    np.testing.assert_raises(AssertionError, np.testing.assert_allclose, preds_sorted_features, preds_scrambled_features_no_check)
63 changes: 53 additions & 10 deletions tests/python_package_test/test_sklearn.py
@@ -17,6 +17,7 @@
 from sklearn.utils.validation import check_is_fitted

 import lightgbm as lgb
+from lightgbm.compat import PANDAS_INSTALLED, pd_DataFrame

 from .utils import (load_boston, load_breast_cancer, load_digits, load_iris, load_linnerud, make_ranking,
                     make_synthetic_regression)
@@ -31,6 +32,24 @@
 from sklearn.utils.estimator_checks import parametrize_with_checks

 decreasing_generator = itertools.count(0, -1)
+task_to_model_factory = {
+    'ranking': lgb.LGBMRanker,
+    'classification': lgb.LGBMClassifier,
+    'regression': lgb.LGBMRegressor,
+}
+
+
+def _create_data(task):
+    if task == 'ranking':
+        X, y, g = make_ranking(n_features=4)
+        g = np.bincount(g)
+    elif task == 'classification':
+        X, y = load_iris(return_X_y=True)
+        g = None
+    elif task == 'regression':
+        X, y = make_synthetic_regression()
+        g = None
+    return X, y, g


 class UnpicklableCallback:
@@ -1325,23 +1344,15 @@ def test_parameters_default_constructible(estimator):
 @pytest.mark.parametrize('task', ['classification', 'ranking', 'regression'])
 def test_training_succeeds_when_data_is_dataframe_and_label_is_column_array(task):
     pd = pytest.importorskip("pandas")
-    if task == 'ranking':
-        X, y, g = make_ranking()
-        g = np.bincount(g)
-        model_factory = lgb.LGBMRanker
-    elif task == 'classification':
-        X, y = load_iris(return_X_y=True)
-        model_factory = lgb.LGBMClassifier
-    elif task == 'regression':
-        X, y = make_synthetic_regression()
-        model_factory = lgb.LGBMRegressor
+    X, y, g = _create_data(task)
     X = pd.DataFrame(X)
     y_col_array = y.reshape(-1, 1)
     params = {
         'n_estimators': 1,
         'num_leaves': 3,
         'random_state': 0
     }
+    model_factory = task_to_model_factory[task]
     with pytest.warns(UserWarning, match='column-vector'):
         if task == 'ranking':
             model_1d = model_factory(**params).fit(X, y, group=g)
@@ -1353,3 +1364,35 @@ def test_training_succeeds_when_data_is_dataframe_and_label_is_column_array(task
         preds_1d = model_1d.predict(X)
         preds_2d = model_2d.predict(X)
         np.testing.assert_array_equal(preds_1d, preds_2d)
+
+
+@pytest.mark.skipif(not PANDAS_INSTALLED, reason='pandas is not installed')
+@pytest.mark.parametrize('task', ['classification', 'ranking', 'regression'])
+def test_validate_features(task):
+    X, y, g = _create_data(task)
+    features = ['x1', 'x2', 'x3', 'x4']
+    df = pd_DataFrame(X, columns=features)
+    model = task_to_model_factory[task](n_estimators=10, num_leaves=15, verbose=-1)
+    if task == 'ranking':
+        model.fit(df, y, group=g)
+    else:
+        model.fit(df, y)
+    assert model.booster_.feature_name() == features
+
+    # try to predict with a different feature
+    df2 = df.rename(columns={'x1': 'z'})
+    with pytest.raises(ValueError, match="The following features are missing: {'x1'}"):
+        model.predict(df2)
+
+    # check that disabling the check doesn't raise the error
+    model.predict(df2, validate_features=False)
+
+    # predict with the features out of order
+    preds_sorted_features = model.predict(df[features])
+    scrambled_features = ['x3', 'x1', 'x4', 'x2']
+    preds_scrambled_features = model.predict(df[scrambled_features])
+    np.testing.assert_allclose(preds_sorted_features, preds_scrambled_features)
+
+    # check that disabling the check doesn't raise an error and produces incorrect predictions
+    preds_scrambled_features_no_check = model.predict(df[scrambled_features], validate_features=False)
+    np.testing.assert_raises(AssertionError, np.testing.assert_allclose, preds_sorted_features, preds_scrambled_features_no_check)