[python-package] move validation up earlier in cv() and train() (#5836)
jameslamb authored Apr 19, 2023
1 parent fd921d5 commit f74875e
Showing 2 changed files with 51 additions and 10 deletions.
29 changes: 19 additions & 10 deletions python-package/lightgbm/engine.py
@@ -141,6 +141,20 @@ def train(
     booster : Booster
         The trained Booster model.
     """
+    if not isinstance(train_set, Dataset):
+        raise TypeError(f"train() only accepts Dataset object, train_set has type '{type(train_set).__name__}'.")
+
+    if num_boost_round <= 0:
+        raise ValueError(f"num_boost_round must be greater than 0. Got {num_boost_round}.")
+
+    if isinstance(valid_sets, list):
+        for i, valid_item in enumerate(valid_sets):
+            if not isinstance(valid_item, Dataset):
+                raise TypeError(
+                    "Every item in valid_sets must be a Dataset object. "
+                    f"Item {i} has type '{type(valid_item).__name__}'."
+                )
+
     # create predictor first
     params = copy.deepcopy(params)
     params = _choose_param_value(
@@ -167,17 +181,12 @@ def train(
         params.pop("early_stopping_round")
     first_metric_only = params.get('first_metric_only', False)
 
-    if num_boost_round <= 0:
-        raise ValueError("num_boost_round should be greater than zero.")
     predictor: Optional[_InnerPredictor] = None
     if isinstance(init_model, (str, Path)):
         predictor = _InnerPredictor(model_file=init_model, pred_parameter=params)
     elif isinstance(init_model, Booster):
         predictor = init_model._to_predictor(pred_parameter=dict(init_model.params, **params))
     init_iteration = predictor.num_total_iteration if predictor is not None else 0
-    # check dataset
-    if not isinstance(train_set, Dataset):
-        raise TypeError("Training only accepts Dataset object")
 
     train_set._update_params(params) \
         ._set_predictor(predictor) \
@@ -200,8 +209,6 @@ def train(
                 if valid_names is not None:
                     train_data_name = valid_names[i]
                 continue
-            if not isinstance(valid_data, Dataset):
-                raise TypeError("Training only accepts Dataset object")
             reduced_valid_sets.append(valid_data._update_params(params).set_reference(train_set))
             if valid_names is not None and len(valid_names) > i:
                 name_valid_sets.append(valid_names[i])
@@ -647,7 +654,11 @@ def cv(
         If ``return_cvbooster=True``, also returns trained boosters via ``cvbooster`` key.
     """
     if not isinstance(train_set, Dataset):
-        raise TypeError("Training only accepts Dataset object")
+        raise TypeError(f"cv() only accepts Dataset object, train_set has type '{type(train_set).__name__}'.")
+
+    if num_boost_round <= 0:
+        raise ValueError(f"num_boost_round must be greater than 0. Got {num_boost_round}.")
+
     params = copy.deepcopy(params)
     params = _choose_param_value(
         main_param_name='objective',
@@ -673,8 +684,6 @@ def cv(
         params.pop("early_stopping_round")
     first_metric_only = params.get('first_metric_only', False)
 
-    if num_boost_round <= 0:
-        raise ValueError("num_boost_round should be greater than zero.")
     if isinstance(init_model, (str, Path)):
         predictor = _InnerPredictor(model_file=init_model, pred_parameter=params)
     elif isinstance(init_model, Booster):
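
For illustration only (not part of the commit): a minimal sketch of how the relocated checks behave for callers, assuming arbitrary placeholder arrays. Invalid inputs now fail immediately, before parameter processing or any boosting work, with the messages asserted in the new tests below.

    import numpy as np
    import lightgbm as lgb

    X = np.random.rand(100, 4)  # placeholder features, for illustration only
    y = np.random.rand(100)     # placeholder labels

    # A non-Dataset train_set is rejected as the very first step of train() and cv().
    try:
        lgb.train({}, train_set=[])
    except TypeError as err:
        print(err)  # train() only accepts Dataset object, train_set has type 'list'.

    # A non-positive num_boost_round is likewise rejected up front.
    try:
        lgb.cv({}, train_set=lgb.Dataset(X, y), num_boost_round=0)
    except ValueError as err:
        print(err)  # num_boost_round must be greater than 0. Got 0.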
32 changes: 32 additions & 0 deletions tests/python_package_test/test_engine.py
@@ -4017,6 +4017,38 @@ def test_validate_features():
     bst.refit(df2, y, validate_features=False)
 
 
+def test_train_and_cv_raise_informative_error_for_train_set_of_wrong_type():
+    with pytest.raises(TypeError, match=r"train\(\) only accepts Dataset object, train_set has type 'list'\."):
+        lgb.train({}, train_set=[])
+    with pytest.raises(TypeError, match=r"cv\(\) only accepts Dataset object, train_set has type 'list'\."):
+        lgb.cv({}, train_set=[])
+
+
+@pytest.mark.parametrize('num_boost_round', [-7, -1, 0])
+def test_train_and_cv_raise_informative_error_for_impossible_num_boost_round(num_boost_round):
+    X, y = make_synthetic_regression(n_samples=100)
+    error_msg = rf"num_boost_round must be greater than 0\. Got {num_boost_round}\."
+    with pytest.raises(ValueError, match=error_msg):
+        lgb.train({}, train_set=lgb.Dataset(X, y), num_boost_round=num_boost_round)
+    with pytest.raises(ValueError, match=error_msg):
+        lgb.cv({}, train_set=lgb.Dataset(X, y), num_boost_round=num_boost_round)
+
+
+def test_train_raises_informative_error_if_any_valid_sets_are_not_dataset_objects():
+    X, y = make_synthetic_regression(n_samples=100)
+    X_valid = X * 2.0
+    with pytest.raises(TypeError, match=r"Every item in valid_sets must be a Dataset object\. Item 1 has type 'tuple'\."):
+        lgb.train(
+            params={},
+            train_set=lgb.Dataset(X, y),
+            valid_sets=[
+                lgb.Dataset(X_valid, y),
+                ([1.0], [2.0]),
+                [5.6, 5.7, 5.8]
+            ]
+        )
+
+
 def test_train_raises_informative_error_for_params_of_wrong_type():
     X, y = make_synthetic_regression()
     params = {"early_stopping_round": "too-many"}
