Skip to content

Commit

Permalink
[python-package] make _InnerPredictor construction stricter (#5961)
Browse files Browse the repository at this point in the history
  • Loading branch information
jameslamb authored Jul 14, 2023
1 parent ed28c84 commit 7d4d897
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 57 deletions.
147 changes: 95 additions & 52 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
'Sequence',
]

_BoosterHandle = ctypes.c_void_p
_DatasetHandle = ctypes.c_void_p
_ctypes_int_ptr = Union[
"ctypes._Pointer[ctypes.c_int32]",
Expand Down Expand Up @@ -837,52 +838,98 @@ class _InnerPredictor:

def __init__(
self,
model_file: Optional[Union[str, Path]] = None,
booster_handle: Optional[ctypes.c_void_p] = None,
pred_parameter: Optional[Dict[str, Any]] = None
booster_handle: _BoosterHandle,
pandas_categorical: Optional[List[List]],
pred_parameter: Dict[str, Any],
manage_handle: bool
):
"""Initialize the _InnerPredictor.
Parameters
----------
model_file : str, pathlib.Path or None, optional (default=None)
Path to the model file.
booster_handle : object or None, optional (default=None)
booster_handle : object
Handle of Booster.
pred_parameter: dict or None, optional (default=None)
pandas_categorical : list of list, or None
If provided, list of categories for ``pandas`` categorical columns.
Where the ``i``th element of the list contains the categories for the ``i``th categorical feature.
pred_parameter : dict
Other parameters for the prediction.
manage_handle : bool
If ``True``, free the corresponding Booster on the C++ side when this Python object is deleted.
"""
self._handle = ctypes.c_void_p()
self.__is_manage_handle = True
if model_file is not None:
"""Prediction task"""
out_num_iterations = ctypes.c_int(0)
_safe_call(_LIB.LGBM_BoosterCreateFromModelfile(
_c_str(str(model_file)),
ctypes.byref(out_num_iterations),
ctypes.byref(self._handle)))
out_num_class = ctypes.c_int(0)
_safe_call(_LIB.LGBM_BoosterGetNumClasses(
self._handle,
ctypes.byref(out_num_class)))
self.num_class = out_num_class.value
self.num_total_iteration = out_num_iterations.value
self.pandas_categorical = _load_pandas_categorical(file_name=model_file)
elif booster_handle is not None:
self.__is_manage_handle = False
self._handle = booster_handle
out_num_class = ctypes.c_int(0)
_safe_call(_LIB.LGBM_BoosterGetNumClasses(
self._handle = booster_handle
self.__is_manage_handle = manage_handle
self.pandas_categorical = pandas_categorical
self.pred_parameter = _param_dict_to_str(pred_parameter)

out_num_class = ctypes.c_int(0)
_safe_call(
_LIB.LGBM_BoosterGetNumClasses(
self._handle,
ctypes.byref(out_num_class)))
self.num_class = out_num_class.value
self.num_total_iteration = self.current_iteration()
self.pandas_categorical = None
else:
raise TypeError('Need model_file or booster_handle to create a predictor')
ctypes.byref(out_num_class)
)
)
self.num_class = out_num_class.value

pred_parameter = {} if pred_parameter is None else pred_parameter
self.pred_parameter = _param_dict_to_str(pred_parameter)
@classmethod
def from_booster(
cls,
booster: "Booster",
pred_parameter: Dict[str, Any]
) -> "_InnerPredictor":
"""Initialize an ``_InnerPredictor`` from a ``Booster``.
Parameters
----------
booster : Booster
Booster.
pred_parameter : dict
Other parameters for the prediction.
"""
out_cur_iter = ctypes.c_int(0)
_safe_call(
_LIB.LGBM_BoosterGetCurrentIteration(
booster._handle,
ctypes.byref(out_cur_iter)
)
)
return cls(
booster_handle=booster._handle,
pandas_categorical=booster.pandas_categorical,
pred_parameter=pred_parameter,
manage_handle=False
)

@classmethod
def from_model_file(
cls,
model_file: Union[str, Path],
pred_parameter: Dict[str, Any]
) -> "_InnerPredictor":
"""Initialize an ``_InnerPredictor`` from a text file containing a LightGBM model.
Parameters
----------
model_file : str or pathlib.Path
Path to the model file.
pred_parameter : dict
Other parameters for the prediction.
"""
booster_handle = ctypes.c_void_p()
out_num_iterations = ctypes.c_int(0)
_safe_call(
_LIB.LGBM_BoosterCreateFromModelfile(
_c_str(str(model_file)),
ctypes.byref(out_num_iterations),
ctypes.byref(booster_handle)
)
)
return cls(
booster_handle=booster_handle,
pandas_categorical=_load_pandas_categorical(file_name=model_file),
pred_parameter=pred_parameter,
manage_handle=True
)

def __del__(self) -> None:
try:
Expand Down Expand Up @@ -3046,7 +3093,7 @@ def __init__(
model_str : str or None, optional (default=None)
Model will be loaded from this string.
"""
self._handle = None
self._handle = ctypes.c_void_p()
self._network = False
self.__need_reload_eval_info = True
self._train_data_name = "training"
Expand Down Expand Up @@ -3097,7 +3144,6 @@ def __init__(
# copy the parameters from train_set
params.update(train_set.get_params())
params_str = _param_dict_to_str(params)
self._handle = ctypes.c_void_p()
_safe_call(_LIB.LGBM_BoosterCreate(
train_set._handle,
_c_str(params_str),
Expand Down Expand Up @@ -3126,7 +3172,6 @@ def __init__(
elif model_file is not None:
# Prediction task
out_num_iterations = ctypes.c_int(0)
self._handle = ctypes.c_void_p()
_safe_call(_LIB.LGBM_BoosterCreateFromModelfile(
_c_str(str(model_file)),
ctypes.byref(out_num_iterations),
Expand Down Expand Up @@ -3905,8 +3950,9 @@ def model_from_string(self, model_str: str) -> "Booster":
self : Booster
Loaded Booster object.
"""
if self._handle is not None:
_safe_call(_LIB.LGBM_BoosterFree(self._handle))
# ensure that existing Booster is freed before replacing it
# with a new one createdfrom file
_safe_call(_LIB.LGBM_BoosterFree(self._handle))
self._free_buffer()
self._handle = ctypes.c_void_p()
out_num_iterations = ctypes.c_int(0)
Expand Down Expand Up @@ -4106,7 +4152,10 @@ def predict(
Prediction result.
Can be sparse or a list of sparse objects (each element represents predictions for one class) for feature contributions (when ``pred_contrib=True``).
"""
predictor = self._to_predictor(pred_parameter=deepcopy(kwargs))
predictor = _InnerPredictor.from_booster(
booster=self,
pred_parameter=deepcopy(kwargs),
)
if num_iteration is None:
if start_iteration <= 0:
num_iteration = self.best_iteration
Expand Down Expand Up @@ -4223,7 +4272,10 @@ def refit(
raise LightGBMError('Cannot refit due to null objective function.')
if dataset_params is None:
dataset_params = {}
predictor = self._to_predictor(pred_parameter=deepcopy(kwargs))
predictor = _InnerPredictor.from_booster(
booster=self,
pred_parameter=deepcopy(kwargs)
)
leaf_preds: np.ndarray = predictor.predict( # type: ignore[assignment]
data=data,
start_iteration=-1,
Expand Down Expand Up @@ -4327,15 +4379,6 @@ def set_leaf_output(
)
return self

def _to_predictor(
self,
pred_parameter: Dict[str, Any]
) -> _InnerPredictor:
"""Convert to predictor."""
predictor = _InnerPredictor(booster_handle=self._handle, pred_parameter=pred_parameter)
predictor.pandas_categorical = self.pandas_categorical
return predictor

def num_feature(self) -> int:
"""Get number of features.
Expand Down
26 changes: 21 additions & 5 deletions python-package/lightgbm/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,10 +183,20 @@ def train(

predictor: Optional[_InnerPredictor] = None
if isinstance(init_model, (str, Path)):
predictor = _InnerPredictor(model_file=init_model, pred_parameter=params)
predictor = _InnerPredictor.from_model_file(
model_file=init_model,
pred_parameter=params
)
elif isinstance(init_model, Booster):
predictor = init_model._to_predictor(pred_parameter=dict(init_model.params, **params))
init_iteration = predictor.num_total_iteration if predictor is not None else 0
predictor = _InnerPredictor.from_booster(
booster=init_model,
pred_parameter=dict(init_model.params, **params)
)

if predictor is not None:
init_iteration = predictor.current_iteration()
else:
init_iteration = 0

train_set._update_params(params) \
._set_predictor(predictor) \
Expand Down Expand Up @@ -685,9 +695,15 @@ def cv(
first_metric_only = params.get('first_metric_only', False)

if isinstance(init_model, (str, Path)):
predictor = _InnerPredictor(model_file=init_model, pred_parameter=params)
predictor = _InnerPredictor.from_model_file(
model_file=init_model,
pred_parameter=params
)
elif isinstance(init_model, Booster):
predictor = init_model._to_predictor(pred_parameter=dict(init_model.params, **params))
predictor = _InnerPredictor.from_booster(
booster=init_model,
pred_parameter=dict(init_model.params, **params)
)
else:
predictor = None

Expand Down

0 comments on commit 7d4d897

Please sign in to comment.