Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python-package] support customizing Dataset creation in Booster.refit() (fixes #3038) #4894

Merged
merged 12 commits into from
Jan 22, 2022
18 changes: 12 additions & 6 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3503,7 +3503,7 @@ def predict(self, data, start_iteration=0, num_iteration=None,
raw_score, pred_leaf, pred_contrib,
data_has_header, is_reshape)

def refit(self, data, label, decay_rate=0.9, **kwargs):
def refit(self, data, label, decay_rate=0.9, kwargs_for_predict=None, kwargs_for_dataset=None):
TremaMiguel marked this conversation as resolved.
Show resolved Hide resolved
"""Refit the existing Booster by new data.

Parameters
Expand All @@ -3516,9 +3516,11 @@ def refit(self, data, label, decay_rate=0.9, **kwargs):
decay_rate : float, optional (default=0.9)
Decay rate of refit,
will use ``leaf_output = decay_rate * old_leaf_output + (1.0 - decay_rate) * new_leaf_output`` to refit trees.
**kwargs
Other parameters for refit.
These parameters will be passed to ``predict`` method.
kwargs_for_predict: dict, optional (default=None)
parameters passed to ``predict`` method.
kwargs_for_dataset: dict, optional (default=None)
additional parameters passed to ``Dataset`` class. If the parameters ``data, label, params`` are contained, they
are removed.

Returns
-------
Expand All @@ -3527,7 +3529,9 @@ def refit(self, data, label, decay_rate=0.9, **kwargs):
"""
if self.__set_objective_to_none:
raise LightGBMError('Cannot refit due to null objective function.')
predictor = self._to_predictor(deepcopy(kwargs))
kwargs_for_predict = {} if kwargs_for_predict is None else kwargs_for_predict
kwargs_for_dataset = {} if kwargs_for_dataset is None else kwargs_for_dataset
predictor = self._to_predictor(deepcopy(kwargs_for_predict))
leaf_preds = predictor.predict(data, -1, pred_leaf=True)
nrow, ncol = leaf_preds.shape
out_is_linear = ctypes.c_int(0)
Expand All @@ -3540,7 +3544,9 @@ def refit(self, data, label, decay_rate=0.9, **kwargs):
default_value=None
)
new_params["linear_tree"] = bool(out_is_linear.value)
train_set = Dataset(data, label, params=new_params)
for arg in ['data', 'label', 'params']:
kwargs_for_dataset.pop(arg, None)
train_set = Dataset(data, label, params=new_params, **kwargs_for_dataset)
new_params['refit_decay_rate'] = decay_rate
new_booster = Booster(new_params, train_set)
# Copy models
Expand Down
51 changes: 51 additions & 0 deletions tests/python_package_test/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,24 @@ def categorize(continuous_x):
return np.digitize(continuous_x, bins=np.arange(0, 1, 0.01))


@pytest.fixture
def artifacts_for_refit_kwargs():
TremaMiguel marked this conversation as resolved.
Show resolved Hide resolved
X = np.array([1, 2, 2]).reshape((3, 1))
label = np.array([1, 2, 3])
data = lgb.basic.Dataset(X, label)
booster = lgb.engine.train(
{
"min_data_in_bin": 1,
"min_data_in_leaf": 1,
"learning_rate": 1,
"boost_from_average": False,
TremaMiguel marked this conversation as resolved.
Show resolved Hide resolved
},
data,
num_boost_round=2,
)
return (X, label, booster)


def test_binary():
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
Expand Down Expand Up @@ -1545,6 +1563,39 @@ def test_refit():
assert err_pred > new_err_pred


def test_refit_kwargs_for_predict(artifacts_for_refit_kwargs):
# check refit accepts kwargs_for_predict
X, label, booster = artifacts_for_refit_kwargs
kwargs_for_dataset = {
"weight": [1.0, 0.0, 1.0],
"reference": None,
"group": None,
"init_score": None,
"feature_name": "auto",
"categorical_feature": "auto",
"free_raw_data": True
}
booster_refit = booster.refit(
X, label, kwargs_for_dataset=kwargs_for_dataset
)
pred = booster_refit.predict(X)
assert pred.shape == (3, )
TremaMiguel marked this conversation as resolved.
Show resolved Hide resolved


def test_refit_kwargs_for_dataset(artifacts_for_refit_kwargs):
# check refit accepts kwargs_for_dataset
X, label, booster = artifacts_for_refit_kwargs
kwargs_for_predict = {
"num_iteration": 0,
"raw_score": False,
}
booster_refit = booster.refit(
X, label, kwargs_for_predict=kwargs_for_predict
)
pred = booster_refit.predict(X)
assert pred.shape == (3, )

jameslamb marked this conversation as resolved.
Show resolved Hide resolved

def test_mape_rf():
X, y = load_boston(return_X_y=True)
params = {
Expand Down