[python-package] mark EarlyStopException as part of public API (#6095)
jameslamb authored Sep 13, 2023
1 parent 1a6e6ff commit 0b3d9da
Showing 3 changed files with 36 additions and 3 deletions.
4 changes: 2 additions & 2 deletions python-package/lightgbm/__init__.py
@@ -6,7 +6,7 @@
from pathlib import Path

from .basic import Booster, Dataset, Sequence, register_logger
from .callback import early_stopping, log_evaluation, record_evaluation, reset_parameter
from .callback import EarlyStopException, early_stopping, log_evaluation, record_evaluation, reset_parameter
from .engine import CVBooster, cv, train

try:
@@ -32,5 +32,5 @@
'train', 'cv',
'LGBMModel', 'LGBMRegressor', 'LGBMClassifier', 'LGBMRanker',
'DaskLGBMRegressor', 'DaskLGBMClassifier', 'DaskLGBMRanker',
'log_evaluation', 'record_evaluation', 'reset_parameter', 'early_stopping',
'log_evaluation', 'record_evaluation', 'reset_parameter', 'early_stopping', 'EarlyStopException',
'plot_importance', 'plot_split_value_histogram', 'plot_metric', 'plot_tree', 'create_tree_digraph']
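
With this export in place, the exception is reachable from the top-level ``lightgbm`` namespace rather than only from the ``lightgbm.callback`` submodule. A minimal sketch, assuming a LightGBM build that includes this commit; the assert is illustrative and not part of the diff:

# Both names now refer to the same public class.
import lightgbm as lgb
from lightgbm.callback import EarlyStopException

assert lgb.EarlyStopException is EarlyStopException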
8 changes: 7 additions & 1 deletion python-package/lightgbm/callback.py
@@ -12,6 +12,7 @@
from .engine import CVBooster

__all__ = [
'EarlyStopException',
'early_stopping',
'log_evaluation',
'record_evaluation',
@@ -30,7 +31,11 @@


class EarlyStopException(Exception):
"""Exception of early stopping."""
"""Exception of early stopping.
Raise this from a callback passed in via keyword argument ``callbacks``
in ``cv()`` or ``train()`` to trigger early stopping.
"""

def __init__(self, best_iteration: int, best_score: _ListOfEvalResultTuples) -> None:
"""Create early stopping exception.
@@ -39,6 +44,7 @@ def __init__(self, best_iteration: int, best_score: _ListOfEvalResultTuples) ->
----------
best_iteration : int
The best iteration stopped.
0-based... pass ``best_iteration=2`` to indicate that the third iteration was the best one.
best_score : list of (eval_name, metric_name, eval_result, is_higher_better) tuple or (eval_name, metric_name, eval_result, is_higher_better, stdv) tuple
Scores for each metric, on each validation set, as of the best iteration.
"""
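For illustration only (not part of the diff), here is a sketch of the pattern the new docstring describes: a user-supplied callback that raises ``EarlyStopException`` when some external condition is met, in this case a wall-clock budget. The helper name and budget value are made up for the example; it assumes the callback environment exposes ``iteration`` and ``evaluation_result_list``, the former of which the new test below also relies on.

import time

import lightgbm as lgb


def make_time_budget_callback(budget_seconds):
    """Return a callback that stops training once a wall-clock budget is exceeded."""
    start = time.monotonic()

    def _callback(env):
        if time.monotonic() - start > budget_seconds:
            # best_iteration is 0-based; report the current iteration as the best
            # one and forward whatever evaluation results are available.
            raise lgb.EarlyStopException(
                best_iteration=env.iteration,
                best_score=env.evaluation_result_list,
            )

    return _callback

Passing ``callbacks=[make_time_budget_callback(60)]`` to ``train()`` or ``cv()`` would then end boosting at the first iteration that finishes after the 60-second mark, with ``best_iteration`` and ``best_score`` taken from the exception.
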
27 changes: 27 additions & 0 deletions tests/python_package_test/test_engine.py
@@ -1092,6 +1092,33 @@ def test_early_stopping_min_delta(first_only, single_metric, greater_is_better):
assert np.greater_equal(last_score, best_score - min_delta).any()


def test_early_stopping_can_be_triggered_via_custom_callback():
    X, y = make_synthetic_regression()

    def _early_stop_after_seventh_iteration(env):
        if env.iteration == 6:
            exc = lgb.EarlyStopException(
                best_iteration=6,
                best_score=[("some_validation_set", "some_metric", 0.708, True)]
            )
            raise exc

    bst = lgb.train(
        params={
            "objective": "regression",
            "verbose": -1,
            "num_leaves": 2
        },
        train_set=lgb.Dataset(X, label=y),
        num_boost_round=23,
        callbacks=[_early_stop_after_seventh_iteration]
    )
    assert bst.num_trees() == 7
    assert bst.best_score["some_validation_set"]["some_metric"] == 0.708
    assert bst.best_iteration == 7
    assert bst.current_iteration() == 7


def test_continue_train():
X, y = make_synthetic_regression()
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
