From 0cff4f83814393f36b3dac07868a851b44256c01 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Wed, 8 Mar 2023 22:29:25 -0600 Subject: [PATCH] [python-package] add type hints for functions accepting dtypes (#5773) --- python-package/lightgbm/basic.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/python-package/lightgbm/basic.py b/python-package/lightgbm/basic.py index 5c278b29db8c..1f40dc98559a 100644 --- a/python-package/lightgbm/basic.py +++ b/python-package/lightgbm/basic.py @@ -240,7 +240,7 @@ def _is_numpy_column_array(data: Any) -> bool: return len(shape) == 2 and shape[1] == 1 -def _cast_numpy_array_to_dtype(array: np.ndarray, dtype: np.dtype) -> np.ndarray: +def _cast_numpy_array_to_dtype(array: np.ndarray, dtype: "np.typing.DTypeLike") -> np.ndarray: """Cast numpy array to given dtype.""" if array.dtype == dtype: return array @@ -264,7 +264,7 @@ def _is_1d_collection(data: Any) -> bool: def _list_to_1d_numpy( data: Any, - dtype=np.float32, + dtype: "np.typing.DTypeLike" = np.float32, name: str = 'list' ) -> np.ndarray: """Convert data to numpy 1-D array.""" @@ -303,7 +303,11 @@ def _is_2d_collection(data: Any) -> bool: ) -def _data_to_2d_numpy(data: Any, dtype: type = np.float32, name: str = 'list') -> np.ndarray: +def _data_to_2d_numpy( + data: Any, + dtype: "np.typing.DTypeLike" = np.float32, + name: str = 'list' +) -> np.ndarray: """Convert data to numpy 2-D array.""" if _is_numpy_2d_array(data): return _cast_numpy_array_to_dtype(data, dtype) @@ -612,7 +616,7 @@ def _c_int_array(data): return (ptr_data, type_data, data) # return `data` to avoid the temporary copy is freed -def _is_allowed_numpy_dtype(dtype) -> bool: +def _is_allowed_numpy_dtype(dtype: type) -> bool: float128 = getattr(np, 'float128', type(None)) return ( issubclass(dtype, (np.integer, np.floating, np.bool_))