[python] add type hints for custom objective and metric functions in scikit-learn interface #4547

Merged (8 commits, Nov 15, 2021)
23 changes: 12 additions & 11 deletions python-package/lightgbm/dask.py
@@ -11,7 +11,7 @@
from copy import deepcopy
from enum import Enum, auto
from functools import partial
-from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Type, Union
from urllib.parse import urlparse

import numpy as np
@@ -21,8 +21,9 @@
from .compat import (DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED, Client, LGBMNotFittedError, concat,
                     dask_Array, dask_array_from_delayed, dask_bag_from_delayed, dask_DataFrame, dask_Series,
                     default_client, delayed, pd_DataFrame, pd_Series, wait)
-from .sklearn import (LGBMClassifier, LGBMModel, LGBMRanker, LGBMRegressor, _lgbmmodel_doc_custom_eval_note,
-                      _lgbmmodel_doc_fit, _lgbmmodel_doc_predict)
+from .sklearn import (LGBMClassifier, LGBMModel, LGBMRanker, LGBMRegressor, _LGBM_ScikitCustomEvalFunction,
+                      _LGBM_ScikitCustomObjectiveFunction, _lgbmmodel_doc_custom_eval_note, _lgbmmodel_doc_fit,
+                      _lgbmmodel_doc_predict)

_DaskCollection = Union[dask_Array, dask_DataFrame, dask_Series]
_DaskMatrixLike = Union[dask_Array, dask_DataFrame]
@@ -400,7 +401,7 @@ def _train(
eval_class_weight: Optional[List[Union[dict, str]]] = None,
eval_init_score: Optional[List[_DaskCollection]] = None,
eval_group: Optional[List[_DaskCollection]] = None,
-eval_metric: Optional[Union[Callable, str, List[Union[Callable, str]]]] = None,
+eval_metric: Optional[Union[_LGBM_ScikitCustomEvalFunction, str, List[Union[_LGBM_ScikitCustomEvalFunction, str]]]] = None,
eval_at: Optional[Iterable[int]] = None,
**kwargs: Any
) -> LGBMModel:
@@ -1029,7 +1030,7 @@ def _lgb_dask_fit(
eval_class_weight: Optional[List[Union[dict, str]]] = None,
eval_init_score: Optional[List[_DaskCollection]] = None,
eval_group: Optional[List[_DaskCollection]] = None,
-eval_metric: Optional[Union[Callable, str, List[Union[Callable, str]]]] = None,
+eval_metric: Optional[Union[_LGBM_ScikitCustomEvalFunction, str, List[Union[_LGBM_ScikitCustomEvalFunction, str]]]] = None,
eval_at: Optional[Iterable[int]] = None,
early_stopping_rounds: Optional[int] = None,
**kwargs: Any
@@ -1096,7 +1097,7 @@ def __init__(
learning_rate: float = 0.1,
n_estimators: int = 100,
subsample_for_bin: int = 200000,
-objective: Optional[Union[Callable, str]] = None,
+objective: Optional[Union[_LGBM_ScikitCustomObjectiveFunction, str]] = None,
class_weight: Optional[Union[dict, str]] = None,
min_split_gain: float = 0.,
min_child_weight: float = 1e-3,
@@ -1165,7 +1166,7 @@ def fit(
eval_sample_weight: Optional[List[_DaskCollection]] = None,
eval_class_weight: Optional[List[Union[dict, str]]] = None,
eval_init_score: Optional[List[_DaskCollection]] = None,
-eval_metric: Optional[Union[Callable, str, List[Union[Callable, str]]]] = None,
+eval_metric: Optional[Union[_LGBM_ScikitCustomEvalFunction, str, List[Union[_LGBM_ScikitCustomEvalFunction, str]]]] = None,
early_stopping_rounds: Optional[int] = None,
**kwargs: Any
) -> "DaskLGBMClassifier":
@@ -1276,7 +1277,7 @@ def __init__(
learning_rate: float = 0.1,
n_estimators: int = 100,
subsample_for_bin: int = 200000,
-objective: Optional[Union[Callable, str]] = None,
+objective: Optional[Union[_LGBM_ScikitCustomObjectiveFunction, str]] = None,
class_weight: Optional[Union[dict, str]] = None,
min_split_gain: float = 0.,
min_child_weight: float = 1e-3,
@@ -1343,7 +1344,7 @@ def fit(
eval_names: Optional[List[str]] = None,
eval_sample_weight: Optional[List[_DaskCollection]] = None,
eval_init_score: Optional[List[_DaskCollection]] = None,
-eval_metric: Optional[Union[Callable, str, List[Union[Callable, str]]]] = None,
+eval_metric: Optional[Union[_LGBM_ScikitCustomEvalFunction, str, List[Union[_LGBM_ScikitCustomEvalFunction, str]]]] = None,
early_stopping_rounds: Optional[int] = None,
**kwargs: Any
) -> "DaskLGBMRegressor":
@@ -1436,7 +1437,7 @@ def __init__(
learning_rate: float = 0.1,
n_estimators: int = 100,
subsample_for_bin: int = 200000,
-objective: Optional[Union[Callable, str]] = None,
+objective: Optional[Union[_LGBM_ScikitCustomObjectiveFunction, str]] = None,
class_weight: Optional[Union[dict, str]] = None,
min_split_gain: float = 0.,
min_child_weight: float = 1e-3,
@@ -1506,7 +1507,7 @@ def fit(
eval_sample_weight: Optional[List[_DaskCollection]] = None,
eval_init_score: Optional[List[_DaskCollection]] = None,
eval_group: Optional[List[_DaskCollection]] = None,
-eval_metric: Optional[Union[Callable, str, List[Union[Callable, str]]]] = None,
+eval_metric: Optional[Union[_LGBM_ScikitCustomEvalFunction, str, List[Union[_LGBM_ScikitCustomEvalFunction, str]]]] = None,
eval_at: Iterable[int] = (1, 2, 3, 4, 5),
early_stopping_rounds: Optional[int] = None,
**kwargs: Any
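Taken together, the Dask signature changes above mean eval_metric is now typed to accept a custom callable following the (y_true, y_pred) -> (eval_name, eval_result, is_higher_better) contract, a built-in metric name, or a list mixing both. A minimal sketch of a conforming callable, assuming a local Dask cluster, synthetic data, and a LightGBM build whose Dask estimators accept eval_set (the rmse helper below is illustrative, not part of this PR):

import dask.array as da
import numpy as np
from distributed import Client, LocalCluster

from lightgbm import DaskLGBMRegressor


def rmse(y_true, y_pred):
    # Custom eval metric: must return (eval_name, eval_result, is_higher_better).
    return "rmse", float(np.sqrt(np.mean((y_true - y_pred) ** 2))), False


if __name__ == "__main__":
    client = Client(LocalCluster(n_workers=2))
    X = da.random.random((1_000, 10), chunks=(100, 10))
    y = da.random.random((1_000,), chunks=(100,))

    model = DaskLGBMRegressor(n_estimators=10)
    # eval_metric may be this callable, a string such as "l2", or a list mixing both.
    model.fit(X, y, eval_set=[(X, y)], eval_metric=rmse)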
40 changes: 35 additions & 5 deletions python-package/lightgbm/sklearn.py
@@ -2,22 +2,52 @@
"""Scikit-learn wrapper interface for LightGBM."""
import copy
from inspect import signature
-from typing import Callable, Dict, Optional, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import scipy.sparse as ss

from .basic import Dataset, LightGBMError, _choose_param_value, _ConfigAliases, _log_warning
from .compat import (SKLEARN_INSTALLED, LGBMNotFittedError, _LGBMAssertAllFinite, _LGBMCheckArray,
                     _LGBMCheckClassificationTargets, _LGBMCheckSampleWeight, _LGBMCheckXY, _LGBMClassifierBase,
                     _LGBMComputeSampleWeight, _LGBMLabelEncoder, _LGBMModelBase, _LGBMRegressorBase, dt_DataTable,
-                    pd_DataFrame)
+                    pd_DataFrame, pd_Series)
from .engine import train

_ArrayLike = Union[np.ndarray, ss.spmatrix, pd_Series]
_EvalResultType = Tuple[str, float, bool]
_GroupType = Union[_ArrayLike, List[int]]

_LGBM_ScikitCustomObjectiveFunction = Union[
    Callable[
        [_ArrayLike, _ArrayLike],
        Tuple[np.ndarray, np.ndarray]
    ],
    Callable[
        [_ArrayLike, _ArrayLike, _GroupType],
        Tuple[np.ndarray, np.ndarray]
    ],
]
StrikerRUS (Collaborator) commented:

Could you please clarify why you are distinguishing all three of these types (_ArrayLike, _GroupType, np.ndarray)? They are all documented as array-like:

y_true : array-like of shape = [n_samples]
The target values.
y_pred : array-like of shape = [n_samples] or shape = [n_samples * n_classes] (for multi-class task)
The predicted values.
Predicted values are returned before any transformation,
e.g. they are raw margin instead of probability of positive class for binary task.
group : array-like
Group/query data.
Only used in the learning-to-rank task.
sum(group) = n_samples.
For example, if you have a 100-document dataset with ``group = [10, 20, 40, 10, 10, 10]``, that means that you have 6 groups,
where the first 10 records are in the first group, records 11-30 are in the second group, records 31-70 are in the third group, etc.
grad : array-like of shape = [n_samples] or shape = [n_samples * n_classes] (for multi-class task)
The value of the first order derivative (gradient) of the loss
with respect to the elements of y_pred for each sample point.
hess : array-like of shape = [n_samples] or shape = [n_samples * n_classes] (for multi-class task)
The value of the second order derivative (Hessian) of the loss
with respect to the elements of y_pred for each sample point.

jameslamb (Author) replied:

I didn't realize that y_true and y_pred could be lists, I thought they had to be a pandas Series, numpy array, or scipy matrix.

For grad and hess, it seems that they cannot be scipy matrices or pandas DataFrames / Series (although I didn't realize they could be lists)

grad, hess = fobj(self.__inner_predict(0), self.train_set)
return self.__boost(grad, hess)

grad : list or numpy 1-D array
The value of the first order derivative (gradient) of the loss
with respect to the elements of score for each sample point.
hess : list or numpy 1-D array
The value of the second order derivative (Hessian) of the loss
with respect to the elements of score for each sample point.

To be honest, I'm pretty unsure about the meaning of "array-like" in different parts of LightGBM's docs, and I'm not always sure which of the following are supported when I see that term:

  • list
  • numpy array
  • scipy sparse matrix
  • h2o datatable
  • pandas DataFrame
  • pandas Series

So I took a best guess based on a quick look through the code, but I probably need to test all of those combinations and then update this PR / the docs as appropriate.

StrikerRUS (Collaborator) replied on Aug 25, 2021:

Yeah, I absolutely agree that array-like everywhere in the sklearn wrapper looks confusing. I might be wrong, but it was written before scikit-learn introduced a formal definition of the array-like term:
https://scikit-learn.org/stable/glossary.html#term-array-like

All these values in custom function signatures are supposed to have exactly 1 dimension, right? I believe it will be safe for now to assign them the following types, which we treat as 1-d arrays internally:

raise TypeError(f"Wrong type({type(data).__name__}) for {name}.\n"
                "It should be list, numpy 1-D array or pandas Series")

For grad and hess, the list_to_1d_numpy function is applied directly:

grad = list_to_1d_numpy(grad, name='gradient')
hess = list_to_1d_numpy(hess, name='hessian')

For weight and group, only np.ndarray is possible, if I'm not mistaken:

elif argc == 3:
    return self.func(labels, preds, dataset.get_weight())
elif argc == 4:
    return self.func(labels, preds, dataset.get_weight(), dataset.get_group())

def get_weight(self):
"""Get the weight of the Dataset.
Returns
-------
weight : numpy array or None
Weight for each data point from the Dataset.
"""
if self.weight is None:
    self.weight = self.get_field('weight')
return self.weight

def get_group(self):
"""Get the group of the Dataset.
Returns
-------
group : numpy array or None
Group/query data.
Only used in the learning-to-rank task.
sum(group) = n_samples.
For example, if you have a 100-document dataset with ``group = [10, 20, 40, 10, 10, 10]``, that means that you have 6 groups,
where the first 10 records are in the first group, records 11-30 are in the second group, records 31-70 are in the third group, etc.
"""
if self.group is None:
    self.group = self.get_field('group')
    if self.group is not None:
        # group data from LightGBM is boundaries data, need to convert to group size
        self.group = np.diff(self.group)
return self.group

if weight is not None:
    self.set_weight(weight)
if group is not None:
    self.set_group(group)

def set_weight(self, weight):
"""Set weight of each instance.
Parameters
----------
weight : list, numpy 1-D array, pandas Series or None
Weight to be set for each data point.
Returns
-------
self : Dataset
Dataset with set weight.
"""
if weight is not None and np.all(weight == 1):
    weight = None
self.weight = weight
if self.handle is not None and weight is not None:
    weight = list_to_1d_numpy(weight, name='weight')
    self.set_field('weight', weight)
    self.weight = self.get_field('weight')  # original values can be modified at cpp side
return self

def set_group(self, group):
"""Set group size of Dataset (used for ranking).
Parameters
----------
group : list, numpy 1-D array, pandas Series or None
Group/query data.
Only used in the learning-to-rank task.
sum(group) = n_samples.
For example, if you have a 100-document dataset with ``group = [10, 20, 40, 10, 10, 10]``, that means that you have 6 groups,
where the first 10 records are in the first group, records 11-30 are in the second group, records 31-70 are in the third group, etc.
Returns
-------
self : Dataset
Dataset with set group.
"""
self.group = group
if self.handle is not None and group is not None:
    group = list_to_1d_numpy(group, np.int32, name='group')
    self.set_field('group', group)
return self

For y_true, the same logic applies as for weight and group:

labels = dataset.get_label()

For y_pred, only np.ndarray is possible:

grad, hess = fobj(self.__inner_predict(0), self.train_set)

feval_ret = eval_function(self.__inner_predict(data_idx), cur_data)

return self.__inner_predict_buffer[data_idx]

self.__inner_predict_buffer[data_idx] = np.empty(n_preds, dtype=np.float64)

jameslamb (Author) replied:

Sorry for taking so long to get back to this one!

I just pushed ea1aada with my best understanding of your comments above, but to be honest I'm still confused about exactly what is allowed.

Here is my interpretation of those comments / links:

  • eval function
    • y_true = list, numpy array, or pandas Series
    • y_pred = numpy array
    • group = numpy array
    • weight = numpy array
  • objective function
    • y_true = list, numpy array, or pandas Series
    • y_pred = numpy array
    • group = numpy array
    • grad (output) = list, numpy array, or pandas Series
    • hess (output) = list, numpy array, or pandas Series

StrikerRUS (Collaborator) replied:

Double-checked this and I think your interpretation is fine. Thanks for the detailed investigation.

_LGBM_ScikitCustomEvalFunction = Union[
    Callable[
        [_ArrayLike, _ArrayLike],
        Union[_EvalResultType, List[_EvalResultType]]
    ],
    Callable[
        [_ArrayLike, _ArrayLike, _ArrayLike],
        Union[_EvalResultType, List[_EvalResultType]]
    ],
    Callable[
        [_ArrayLike, _ArrayLike, _ArrayLike, _GroupType],
        Union[_EvalResultType, List[_EvalResultType]]
    ],
]


class _ObjectiveFunctionWrapper:
"""Proxy class for objective function."""

-def __init__(self, func):
+def __init__(self, func: _LGBM_ScikitCustomObjectiveFunction):
"""Construct a proxy class.

This class transforms objective function to match objective function with signature ``new_func(preds, dataset)``
@@ -106,7 +136,7 @@ def __call__(self, preds, dataset):
class _EvalFunctionWrapper:
"""Proxy class for evaluation function."""

-def __init__(self, func):
+def __init__(self, func: _LGBM_ScikitCustomEvalFunction):
"""Construct a proxy class.

This class transforms evaluation function to match evaluation function with signature ``new_func(preds, dataset)``
@@ -357,7 +387,7 @@ def __init__(
learning_rate: float = 0.1,
n_estimators: int = 100,
subsample_for_bin: int = 200000,
-objective: Optional[Union[str, Callable]] = None,
+objective: Optional[Union[str, _LGBM_ScikitCustomObjectiveFunction]] = None,
class_weight: Optional[Union[Dict, str]] = None,
min_split_gain: float = 0.,
min_child_weight: float = 1e-3,