Skip to content

Commit

Permalink
[python-package] add type hints on Dataset feature processing (#5745)
Browse files: browse the repository at this point in the history
  • Loading branch information
jameslamb authored Feb 26, 2023
1 parent 0007343 commit 39ed8ea
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 16 deletions.
32 changes: 21 additions & 11 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@
_LGBM_EvalFunctionResultType = Tuple[str, float, bool]
_LGBM_BoosterBestScoreType = Dict[str, Dict[str, float]]
_LGBM_BoosterEvalMethodResultType = Tuple[str, str, float, bool]
_LGBM_CategoricalFeatureConfiguration = Union[List[str], List[int], str]
_LGBM_FeatureNameConfiguration = Union[List[str], str]
_LGBM_LabelType = Union[
list,
np.ndarray,
Expand Down Expand Up @@ -588,7 +590,12 @@ def _check_for_bad_pandas_dtypes(pandas_dtypes_series: pd_Series) -> None:
f'Fields with bad pandas dtypes: {", ".join(bad_pandas_dtypes)}')


def _data_from_pandas(data, feature_name, categorical_feature, pandas_categorical):
def _data_from_pandas(
data,
feature_name: Optional[_LGBM_FeatureNameConfiguration],
categorical_feature: Optional[_LGBM_CategoricalFeatureConfiguration],
pandas_categorical: Optional[List[List]]
):
if isinstance(data, pd_DataFrame):
if len(data.shape) != 2 or data.shape[0] < 1:
raise ValueError('Input data must be 2 dimensional and non empty.')
Expand Down Expand Up @@ -638,7 +645,10 @@ def _data_from_pandas(data, feature_name, categorical_feature, pandas_categorica
return data, feature_name, categorical_feature, pandas_categorical


def _dump_pandas_categorical(pandas_categorical, file_name=None):
def _dump_pandas_categorical(
pandas_categorical: Optional[List[List]],
file_name: Optional[Union[str, Path]] = None
) -> str:
categorical_json = json.dumps(pandas_categorical, default=_json_default_with_numpy)
pandas_str = f'\npandas_categorical:{categorical_json}\n'
if file_name is not None:
Expand All @@ -650,7 +660,7 @@ def _dump_pandas_categorical(pandas_categorical, file_name=None):
def _load_pandas_categorical(
file_name: Optional[Union[str, Path]] = None,
model_str: Optional[str] = None
) -> Optional[str]:
) -> Optional[List[List]]:
pandas_key = 'pandas_categorical:'
offset = -len(pandas_key)
if file_name is not None:
Expand Down Expand Up @@ -1320,8 +1330,8 @@ def __init__(
weight=None,
group=None,
init_score=None,
feature_name='auto',
categorical_feature='auto',
feature_name: _LGBM_FeatureNameConfiguration = 'auto',
categorical_feature: _LGBM_CategoricalFeatureConfiguration = 'auto',
params: Optional[Dict[str, Any]] = None,
free_raw_data: bool = True
):
Expand Down Expand Up @@ -1371,8 +1381,8 @@ def __init__(
self.weight = weight
self.group = group
self.init_score = init_score
self.feature_name = feature_name
self.categorical_feature = categorical_feature
self.feature_name: _LGBM_FeatureNameConfiguration = feature_name
self.categorical_feature: _LGBM_CategoricalFeatureConfiguration = categorical_feature
self.params = deepcopy(params)
self.free_raw_data = free_raw_data
self.used_indices: Optional[List[int]] = None
Expand Down Expand Up @@ -2294,13 +2304,13 @@ def get_field(self, field_name: str) -> Optional[np.ndarray]:

def set_categorical_feature(
self,
categorical_feature: Union[List[int], List[str], str]
categorical_feature: _LGBM_CategoricalFeatureConfiguration
) -> "Dataset":
"""Set categorical features.
Parameters
----------
categorical_feature : list of int or str
categorical_feature : list of str or int, or 'auto'
Names or indices of categorical features.
Returns
Expand Down Expand Up @@ -3937,8 +3947,8 @@ def refit(
weight=None,
group=None,
init_score=None,
feature_name: Union[str, List[str]] = 'auto',
categorical_feature: Union[str, List[str], List[int]] = 'auto',
feature_name: _LGBM_FeatureNameConfiguration = 'auto',
categorical_feature: _LGBM_CategoricalFeatureConfiguration = 'auto',
dataset_params: Optional[Dict[str, Any]] = None,
free_raw_data: bool = True,
validate_features: bool = False,
Expand Down
11 changes: 6 additions & 5 deletions python-package/lightgbm/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,8 @@

from . import callback
from .basic import (Booster, Dataset, LightGBMError, _choose_param_value, _ConfigAliases, _InnerPredictor,
_LGBM_CustomObjectiveFunction, _log_warning)
_LGBM_CategoricalFeatureConfiguration, _LGBM_CustomObjectiveFunction,
_LGBM_FeatureNameConfiguration, _log_warning)
from .compat import SKLEARN_INSTALLED, _LGBMBaseCrossValidator, _LGBMGroupKFold, _LGBMStratifiedKFold

__all__ = [
Expand Down Expand Up @@ -40,8 +41,8 @@ def train(
valid_names: Optional[List[str]] = None,
feval: Optional[Union[_LGBM_CustomMetricFunction, List[_LGBM_CustomMetricFunction]]] = None,
init_model: Optional[Union[str, Path, Booster]] = None,
feature_name: Union[List[str], str] = 'auto',
categorical_feature: Union[List[str], List[int], str] = 'auto',
feature_name: _LGBM_FeatureNameConfiguration = 'auto',
categorical_feature: _LGBM_CategoricalFeatureConfiguration = 'auto',
keep_training_booster: bool = False,
callbacks: Optional[List[Callable]] = None
) -> Booster:
Expand Down Expand Up @@ -523,8 +524,8 @@ def cv(
metrics: Optional[Union[str, List[str]]] = None,
feval: Optional[Union[_LGBM_CustomMetricFunction, List[_LGBM_CustomMetricFunction]]] = None,
init_model: Optional[Union[str, Path, Booster]] = None,
feature_name: Union[str, List[str]] = 'auto',
categorical_feature: Union[str, List[str], List[int]] = 'auto',
feature_name: _LGBM_FeatureNameConfiguration = 'auto',
categorical_feature: _LGBM_CategoricalFeatureConfiguration = 'auto',
fpreproc: Optional[_LGBM_PreprocFunction] = None,
seed: int = 0,
callbacks: Optional[List[Callable]] = None,
Expand Down

0 comments on commit 39ed8ea

Please sign in to comment.