diff --git a/include/LightGBM/c_api.h b/include/LightGBM/c_api.h
index e0df38eeb82e..ed639b7f298c 100644
--- a/include/LightGBM/c_api.h
+++ b/include/LightGBM/c_api.h
@@ -572,6 +572,9 @@ LIGHTGBM_C_EXPORT int LGBM_BoosterRefit(BoosterHandle handle,
 /*!
  * \brief Update the model by specifying gradient and Hessian directly
  *        (this can be used to support customized loss functions).
+ * \note
+ * The length of the arrays referenced by ``grad`` and ``hess`` must be equal to
+ * ``num_class * num_train_data``; this is not verified by the library, so the caller must ensure it.
  * \param handle Handle of booster
  * \param grad The first order derivative (gradient) statistics
  * \param hess The second order derivative (Hessian) statistics
diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py
index 64f1cb31edaa..c6aa2a69e88c 100644
--- a/python-package/lightgbm/basic.py
+++ b/python-package/lightgbm/basic.py
@@ -3031,7 +3031,15 @@ def __boost(self, grad, hess):
         assert grad.flags.c_contiguous
         assert hess.flags.c_contiguous
         if len(grad) != len(hess):
-            raise ValueError(f"Lengths of gradient({len(grad)}) and hessian({len(hess)}) don't match")
+            raise ValueError(f"Lengths of gradient ({len(grad)}) and Hessian ({len(hess)}) don't match")
+        num_train_data = self.train_set.num_data()
+        num_models = self.__num_class
+        if len(grad) != num_train_data * num_models:
+            raise ValueError(
+                f"Lengths of gradient ({len(grad)}) and Hessian ({len(hess)}) "
+                f"don't match training data length ({num_train_data}) * "
+                f"number of models per one iteration ({num_models})"
+            )
         is_finished = ctypes.c_int(0)
         _safe_call(_LIB.LGBM_BoosterUpdateOneIterCustom(
             self.handle,
diff --git a/tests/python_package_test/test_basic.py b/tests/python_package_test/test_basic.py
index 18a8403eba85..99f581775199 100644
--- a/tests/python_package_test/test_basic.py
+++ b/tests/python_package_test/test_basic.py
@@ -1,6 +1,7 @@
 # coding: utf-8
 import filecmp
 import numbers
+import re
 from pathlib import Path
 
 import numpy as np
@@ -579,3 +580,32 @@ def test_param_aliases():
     assert all(len(i) >= 1 for i in aliases.values())
     assert all(k in v for k, v in aliases.items())
     assert lgb.basic._ConfigAliases.get('config', 'task') == {'config', 'config_file', 'task', 'task_type'}
+
+
+def _bad_gradients(preds, _):
+    return np.random.randn(len(preds) + 1), np.random.rand(len(preds) + 1)
+
+
+def _good_gradients(preds, _):
+    return np.random.randn(len(preds)), np.random.rand(len(preds))
+
+
+def test_custom_objective_safety():
+    nrows = 100
+    X = np.random.randn(nrows, 5)
+    y_binary = np.arange(nrows) % 2
+    classes = [0, 1, 2]
+    nclass = len(classes)
+    y_multiclass = np.arange(nrows) % nclass
+    ds_binary = lgb.Dataset(X, y_binary).construct()
+    ds_multiclass = lgb.Dataset(X, y_multiclass).construct()
+    bad_bst_binary = lgb.Booster({'objective': "none"}, ds_binary)
+    good_bst_binary = lgb.Booster({'objective': "none"}, ds_binary)
+    bad_bst_multi = lgb.Booster({'objective': "none", "num_class": nclass}, ds_multiclass)
+    good_bst_multi = lgb.Booster({'objective': "none", "num_class": nclass}, ds_multiclass)
+    good_bst_binary.update(fobj=_good_gradients)
+    with pytest.raises(ValueError, match=re.escape("number of models per one iteration (1)")):
+        bad_bst_binary.update(fobj=_bad_gradients)
+    good_bst_multi.update(fobj=_good_gradients)
+    with pytest.raises(ValueError, match=re.escape(f"number of models per one iteration ({nclass})")):
+        bad_bst_multi.update(fobj=_bad_gradients)
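
For reference, a minimal sketch (not part of the patch) of the usage pattern the new check in `Booster.__boost()` validates: a custom objective passed via `Booster.update(fobj=...)` must return gradient and Hessian arrays of length `num_class * num_train_data`. The softmax objective below, and the assumed flat, class-major layout of `preds`, are illustrative assumptions, not code from this PR.

```python
import numpy as np
import lightgbm as lgb

nrows, nclass = 100, 3
rng = np.random.default_rng(0)
X = rng.standard_normal((nrows, 5))
y = np.arange(nrows) % nclass


def softmax_objective(preds, train_data):
    # Illustrative multiclass objective; assumes ``preds`` is a flat,
    # class-major array of length num_class * num_train_data.
    labels = train_data.get_label().astype(int)
    scores = preds.reshape(nclass, -1).T               # (num_data, num_class)
    prob = np.exp(scores - scores.max(axis=1, keepdims=True))
    prob /= prob.sum(axis=1, keepdims=True)
    grad = prob - np.eye(nclass)[labels]               # p - y
    hess = prob * (1.0 - prob)                         # common diagonal approximation
    # Flatten back so len(grad) == len(hess) == num_class * num_train_data,
    # which is exactly the length the new check enforces.
    return grad.T.reshape(-1), hess.T.reshape(-1)


ds = lgb.Dataset(X, y).construct()
bst = lgb.Booster({'objective': 'none', 'num_class': nclass}, ds)
bst.update(fobj=softmax_objective)  # passes: lengths equal nclass * nrows
```

Before this patch, a wrongly sized return value was handed straight to `LGBM_BoosterUpdateOneIterCustom`, which (per the new `\note` in `c_api.h`) does not verify array lengths; now `Booster.update()` raises a `ValueError` on the Python side instead.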