Skip to content

Commit

Permalink
[python-package] [ci] fix mypy errors in Booster.__inner_predict() (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
jameslamb authored Apr 25, 2023
1 parent ef5acfb commit 8670013
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3110,7 +3110,7 @@ def __init__(
ctypes.byref(out_num_class)))
self.__num_class = out_num_class.value
# buffer for inner predict
self.__inner_predict_buffer = [None]
self.__inner_predict_buffer: List[Optional[np.ndarray]] = [None]
self.__is_predicted_cur_iter = [False]
self.__get_eval_info()
self.pandas_categorical = train_set.pandas_categorical
Expand Down Expand Up @@ -4518,16 +4518,16 @@ def __inner_predict(self, data_idx: int) -> np.ndarray:
# avoid to predict many time in one iteration
if not self.__is_predicted_cur_iter[data_idx]:
tmp_out_len = ctypes.c_int64(0)
data_ptr = self.__inner_predict_buffer[data_idx].ctypes.data_as(ctypes.POINTER(ctypes.c_double))
data_ptr = self.__inner_predict_buffer[data_idx].ctypes.data_as(ctypes.POINTER(ctypes.c_double)) # type: ignore[union-attr]
_safe_call(_LIB.LGBM_BoosterGetPredict(
self.handle,
ctypes.c_int(data_idx),
ctypes.byref(tmp_out_len),
data_ptr))
if tmp_out_len.value != len(self.__inner_predict_buffer[data_idx]):
if tmp_out_len.value != len(self.__inner_predict_buffer[data_idx]): # type: ignore[arg-type]
raise ValueError(f"Wrong length of predict results for data {data_idx}")
self.__is_predicted_cur_iter[data_idx] = True
result = self.__inner_predict_buffer[data_idx]
result: np.ndarray = self.__inner_predict_buffer[data_idx] # type: ignore[assignment]
if self.__num_class > 1:
num_data = result.size // self.__num_class
result = result.reshape(num_data, self.__num_class, order='F')
Expand Down

0 comments on commit 8670013

Please sign in to comment.