From 7c9a985a476309ba1d7a7767b5247a3d4eb579b9 Mon Sep 17 00:00:00 2001
From: James Lamb
Date: Thu, 21 Sep 2023 21:37:03 -0500
Subject: [PATCH] [python-package] fix mypy errors in Dataset construction (#6106)

---
 python-package/lightgbm/basic.py | 27 ++++++++++++++++++++++++---
 1 file changed, 24 insertions(+), 3 deletions(-)

diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py
index e17e8513f7ca..c2964fcedd8d 100644
--- a/python-package/lightgbm/basic.py
+++ b/python-package/lightgbm/basic.py
@@ -24,6 +24,13 @@
 if TYPE_CHECKING:
     from typing import Literal
 
+    # typing.TypeGuard was only introduced in Python 3.10
+    try:
+        from typing import TypeGuard
+    except ImportError:
+        from typing_extensions import TypeGuard
+
+
 __all__ = [
     'Booster',
     'Dataset',
@@ -279,6 +286,20 @@ def _is_1d_list(data: Any) -> bool:
     return isinstance(data, list) and (not data or _is_numeric(data[0]))
 
 
+def _is_list_of_numpy_arrays(data: Any) -> "TypeGuard[List[np.ndarray]]":
+    return (
+        isinstance(data, list)
+        and all(isinstance(x, np.ndarray) for x in data)
+    )
+
+
+def _is_list_of_sequences(data: Any) -> "TypeGuard[List[Sequence]]":
+    return (
+        isinstance(data, list)
+        and all(isinstance(x, Sequence) for x in data)
+    )
+
+
 def _is_1d_collection(data: Any) -> bool:
     """Check whether data is a 1-D collection."""
     return (
@@ -1918,9 +1939,9 @@ def _lazy_init(
         elif isinstance(data, np.ndarray):
             self.__init_from_np2d(data, params_str, ref_dataset)
         elif isinstance(data, list) and len(data) > 0:
-            if all(isinstance(x, np.ndarray) for x in data):
+            if _is_list_of_numpy_arrays(data):
                 self.__init_from_list_np2d(data, params_str, ref_dataset)
-            elif all(isinstance(x, Sequence) for x in data):
+            elif _is_list_of_sequences(data):
                 self.__init_from_seqs(data, ref_dataset)
             else:
                 raise TypeError('Data list can only be of ndarray or Sequence')
@@ -2870,7 +2891,7 @@ def get_data(self) -> Optional[_LGBM_TrainDataType]:
                     self.data = self.data[self.used_indices, :]
                 elif isinstance(self.data, Sequence):
                     self.data = self.data[self.used_indices]
-                elif isinstance(self.data, list) and len(self.data) > 0 and all(isinstance(x, Sequence) for x in self.data):
+                elif _is_list_of_sequences(self.data) and len(self.data) > 0:
                     self.data = np.array(list(self._yield_row_from_seqlist(self.data, self.used_indices)))
                 else:
                     _log_warning(f"Cannot subset {type(self.data).__name__} type of raw data.\n"