Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python-package] allow custom weighing in fobj for scikit-learn API (closes #5027) #5211

Merged
merged 3 commits into from
Jun 27, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 10 additions & 13 deletions python-package/lightgbm/sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,10 @@ def __init__(self, func: _LGBM_ScikitCustomObjectiveFunction):
Parameters
----------
func : callable
Expects a callable with signature ``func(y_true, y_pred)`` or ``func(y_true, y_pred, group)``
Expects a callable with following signatures:
``func(y_true, y_pred)``,
``func(y_true, y_pred, weight)``
or ``func(y_true, y_pred, weight, group)``
StrikerRUS marked this conversation as resolved.
Show resolved Hide resolved
StrikerRUS marked this conversation as resolved.
Show resolved Hide resolved
and returns (grad, hess):

y_true : numpy 1-D array of shape = [n_samples]
Expand All @@ -63,6 +66,8 @@ def __init__(self, func: _LGBM_ScikitCustomObjectiveFunction):
The predicted values.
Predicted values are returned before any transformation,
e.g. they are raw margin instead of probability of positive class for binary task.
weight : numpy 1-D array of shape = [n_samples]
The weight of samples. Weights should be non-negative.
group : numpy 1-D array
Group/query data.
Only used in the learning-to-rank task.
Expand Down Expand Up @@ -107,19 +112,11 @@ def __call__(self, preds, dataset):
if argc == 2:
grad, hess = self.func(labels, preds)
elif argc == 3:
grad, hess = self.func(labels, preds, dataset.get_group())
grad, hess = self.func(labels, preds, dataset.get_weight())
elif argc == 4:
grad, hess = self.func(labels, preds, dataset.get_weight(), dataset.get_group())
else:
raise TypeError(f"Self-defined objective function should have 2 or 3 arguments, got {argc}")
"""weighted for objective"""
weight = dataset.get_weight()
if weight is not None:
if grad.ndim == 2: # multi-class
num_data = grad.shape[0]
if weight.size != num_data:
raise ValueError("grad and hess should be of shape [n_samples, n_classes]")
weight = weight.reshape(num_data, 1)
grad *= weight
hess *= weight
raise TypeError(f"Self-defined objective function should have 2, 3 or 4 arguments, got {argc}")
return grad, hess


Expand Down
25 changes: 20 additions & 5 deletions tests/python_package_test/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -2420,14 +2420,20 @@ def test_default_objective_and_metric():
assert len(evals_result['valid_0']['l2']) == 5


def test_multiclass_custom_objective():
@pytest.mark.parametrize('use_weight', [True, False])
def test_multiclass_custom_objective(use_weight):
def custom_obj(y_pred, ds):
y_true = ds.get_label()
return sklearn_multiclass_custom_objective(y_true, y_pred)
weight = ds.get_weight()
grad, hess = sklearn_multiclass_custom_objective(y_true, y_pred, weight)
return grad, hess

centers = [[-4, -4], [4, 4], [-4, 4]]
X, y = make_blobs(n_samples=1_000, centers=centers, random_state=42)
weight = np.full_like(y, 2)
ds = lgb.Dataset(X, y)
if use_weight:
ds.set_weight(weight)
params = {'objective': 'multiclass', 'num_class': 3, 'num_leaves': 7}
builtin_obj_bst = lgb.train(params, ds, num_boost_round=10)
builtin_obj_preds = builtin_obj_bst.predict(X)
Expand All @@ -2439,16 +2445,25 @@ def custom_obj(y_pred, ds):
np.testing.assert_allclose(builtin_obj_preds, custom_obj_preds, rtol=0.01)


def test_multiclass_custom_eval():
@pytest.mark.parametrize('use_weight', [True, False])
def test_multiclass_custom_eval(use_weight):
def custom_eval(y_pred, ds):
y_true = ds.get_label()
return 'custom_logloss', log_loss(y_true, y_pred), False
weight = ds.get_weight() # weight is None when not set
loss = log_loss(y_true, y_pred, sample_weight=weight)
return 'custom_logloss', loss, False

centers = [[-4, -4], [4, 4], [-4, 4]]
X, y = make_blobs(n_samples=1_000, centers=centers, random_state=42)
X_train, X_valid, y_train, y_valid = train_test_split(X, y, test_size=0.2, random_state=0)
weight = np.full_like(y, 2)
X_train, X_valid, y_train, y_valid, weight_train, weight_valid = train_test_split(
X, y, weight, test_size=0.2, random_state=0
)
train_ds = lgb.Dataset(X_train, y_train)
valid_ds = lgb.Dataset(X_valid, y_valid, reference=train_ds)
if use_weight:
train_ds.set_weight(weight_train)
valid_ds.set_weight(weight_valid)
params = {'objective': 'multiclass', 'num_class': 3, 'num_leaves': 7}
eval_result = {}
bst = lgb.train(
Expand Down
52 changes: 49 additions & 3 deletions tests/python_package_test/test_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -1273,18 +1273,64 @@ def test_training_succeeds_when_data_is_dataframe_and_label_is_column_array(task
np.testing.assert_array_equal(preds_1d, preds_2d)


def test_multiclass_custom_objective():
@pytest.mark.parametrize('use_weight', [True, False])
def test_multiclass_custom_objective(use_weight):
centers = [[-4, -4], [4, 4], [-4, 4]]
X, y = make_blobs(n_samples=1_000, centers=centers, random_state=42)
weight = np.full_like(y, 2) if use_weight else None
params = {'n_estimators': 10, 'num_leaves': 7}
builtin_obj_model = lgb.LGBMClassifier(**params)
builtin_obj_model.fit(X, y)
builtin_obj_model.fit(X, y, sample_weight=weight)
builtin_obj_preds = builtin_obj_model.predict_proba(X)

custom_obj_model = lgb.LGBMClassifier(objective=sklearn_multiclass_custom_objective, **params)
custom_obj_model.fit(X, y)
custom_obj_model.fit(X, y, sample_weight=weight)
custom_obj_preds = softmax(custom_obj_model.predict(X, raw_score=True))

np.testing.assert_allclose(builtin_obj_preds, custom_obj_preds, rtol=0.01)
assert not callable(builtin_obj_model.objective_)
assert callable(custom_obj_model.objective_)


@pytest.mark.parametrize('use_weight', [True, False])
def test_multiclass_custom_eval(use_weight):
def custom_eval(y_true, y_pred, weight):
loss = log_loss(y_true, y_pred, sample_weight=weight)
return 'custom_logloss', loss, False

centers = [[-4, -4], [4, 4], [-4, 4]]
X, y = make_blobs(n_samples=1_000, centers=centers, random_state=42)
if use_weight:
weight = np.full_like(y, 2)
X_train, X_valid, y_train, y_valid, weight_train, weight_valid = train_test_split(
X, y, weight, test_size=0.2, random_state=0
)
else:
X_train, X_valid, y_train, y_valid = train_test_split(
X, y, test_size=0.2, random_state=0
)
weight_train = None
weight_valid = None
jmoralez marked this conversation as resolved.
Show resolved Hide resolved
params = {'objective': 'multiclass', 'num_class': 3, 'num_leaves': 7}
eval_result = {}
model = lgb.LGBMClassifier(**params)
model.fit(
X_train,
y_train,
sample_weight=weight_train,
eval_set=[(X_train, y_train), (X_valid, y_valid)],
eval_names=['train', 'valid'],
eval_sample_weight=[weight_train, weight_valid],
eval_metric=custom_eval,
callbacks=[lgb.record_evaluation(eval_result)],
)
jmoralez marked this conversation as resolved.
Show resolved Hide resolved

train_ds = (X_train, y_train, weight_train)
valid_ds = (X_valid, y_valid, weight_valid)
for key, (X, y_true, weight) in zip(['train', 'valid'], [train_ds, valid_ds]):
np.testing.assert_allclose(
eval_result[key]['multi_logloss'], eval_result[key]['custom_logloss']
)
y_pred = model.predict_proba(X)
_, metric_value, _ = custom_eval(y_true, y_pred, weight)
np.testing.assert_allclose(metric_value, eval_result[key]['custom_logloss'][-1])
6 changes: 5 additions & 1 deletion tests/python_package_test/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,14 +140,18 @@ def logistic_sigmoid(x):
return 1.0 / (1.0 + np.exp(-x))


def sklearn_multiclass_custom_objective(y_true, y_pred):
def sklearn_multiclass_custom_objective(y_true, y_pred, weight=None):
num_rows, num_class = y_pred.shape
prob = softmax(y_pred)
grad_update = np.zeros_like(prob)
grad_update[np.arange(num_rows), y_true.astype(np.int32)] = -1.0
grad = prob + grad_update
factor = num_class / (num_class - 1)
hess = factor * prob * (1 - prob)
if weight is not None:
weight2d = weight.reshape(-1, 1)
grad *= weight2d
hess *= weight2d
return grad, hess


Expand Down