Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python-package] support customizing Dataset creation in Booster.refit() (fixes #3038) #4894

Merged
merged 12 commits into from
Jan 22, 2022
57 changes: 55 additions & 2 deletions python-package/lightgbm/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3503,7 +3503,21 @@ def predict(self, data, start_iteration=0, num_iteration=None,
raw_score, pred_leaf, pred_contrib,
data_has_header, is_reshape)

def refit(self, data, label, decay_rate=0.9, **kwargs):
def refit(
self,
data,
label,
decay_rate=0.9,
reference=None,
weight=None,
group=None,
init_score=None,
feature_name='auto',
categorical_feature='auto',
dataset_params=None,
free_raw_data=True,
**kwargs
):
"""Refit the existing Booster by new data.

Parameters
Expand All @@ -3516,6 +3530,31 @@ def refit(self, data, label, decay_rate=0.9, **kwargs):
decay_rate : float, optional (default=0.9)
Decay rate of refit,
will use ``leaf_output = decay_rate * old_leaf_output + (1.0 - decay_rate) * new_leaf_output`` to refit trees.
reference : Dataset or None, optional (default=None)
reference for ``data``.
TremaMiguel marked this conversation as resolved.
Show resolved Hide resolved
weight : list, numpy 1-D array, pandas Series or None, optional (default=None)
Weight for each ``data`` instance. Weight should be non-negative values because the Hessian
value multiplied by weight is supposed to be non-negative.
TremaMiguel marked this conversation as resolved.
Show resolved Hide resolved
group : list, numpy 1-D array, pandas Series or None, optional (default=None)
Group/query size for ``data``.
TremaMiguel marked this conversation as resolved.
Show resolved Hide resolved
init_score : list, numpy 1-D array, pandas Series or None, optional (default=None)
TremaMiguel marked this conversation as resolved.
Show resolved Hide resolved
Init score for ``data``.
feature_name : list of strings or 'auto', optional (default="auto")
TremaMiguel marked this conversation as resolved.
Show resolved Hide resolved
Feature names for ``data``.
If 'auto' and data is pandas DataFrame, data columns names are used.
categorical_feature : list of strings or int, or 'auto', optional (default="auto")
TremaMiguel marked this conversation as resolved.
Show resolved Hide resolved
Categorical features for ``data``.
If list of int, interpreted as indices.
If list of strings, interpreted as feature names (need to specify ``feature_name`` as well).
TremaMiguel marked this conversation as resolved.
Show resolved Hide resolved
If 'auto' and data is pandas DataFrame, pandas unordered categorical columns are used.
All values in categorical features should be less than int32 max value (2147483647).
Large values could be memory consuming. Consider using consecutive integers starting from zero.
All negative values in categorical features will be treated as missing values.
The output cannot be monotonically constrained with respect to a categorical feature.
dataset_params : dict or None, optional (default=None)
Other parameters for Dataset ``data``.
free_raw_data : bool, optional (default=True)
If True, raw data is freed after constructing inner Dataset for ``data``.
**kwargs
Other parameters for refit.
These parameters will be passed to ``predict`` method.
Expand All @@ -3527,6 +3566,8 @@ def refit(self, data, label, decay_rate=0.9, **kwargs):
"""
if self.__set_objective_to_none:
raise LightGBMError('Cannot refit due to null objective function.')
if dataset_params is None:
dataset_params = {}
predictor = self._to_predictor(deepcopy(kwargs))
leaf_preds = predictor.predict(data, -1, pred_leaf=True)
nrow, ncol = leaf_preds.shape
Expand All @@ -3540,7 +3581,19 @@ def refit(self, data, label, decay_rate=0.9, **kwargs):
default_value=None
)
new_params["linear_tree"] = bool(out_is_linear.value)
train_set = Dataset(data, label, params=new_params)
new_params.update(dataset_params)
train_set = Dataset(
data=data,
label=label,
reference=reference,
weight=weight,
group=group,
init_score=init_score,
feature_name=feature_name,
categorical_feature=categorical_feature,
params=new_params,
free_raw_data=free_raw_data,
)
new_params['refit_decay_rate'] = decay_rate
new_booster = Booster(new_params, train_set)
# Copy models
Expand Down
29 changes: 29 additions & 0 deletions tests/python_package_test/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -1545,6 +1545,35 @@ def test_refit():
assert err_pred > new_err_pred


def test_refit_dataset_params():
# check refit accepts dataset_params
X, y = load_breast_cancer(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=42)
TremaMiguel marked this conversation as resolved.
Show resolved Hide resolved
lgb_train = lgb.Dataset(X_train, y_train)
TremaMiguel marked this conversation as resolved.
Show resolved Hide resolved
train_params = {
'objective': 'binary',
'metric': 'binary_logloss',
TremaMiguel marked this conversation as resolved.
Show resolved Hide resolved
'verbose': -1,
'min_data': 10,
TremaMiguel marked this conversation as resolved.
Show resolved Hide resolved
'seed': 123
}
gbm = lgb.train(train_params, lgb_train, num_boost_round=20)
TremaMiguel marked this conversation as resolved.
Show resolved Hide resolved
non_weight_err_pred = log_loss(y_test, gbm.predict(X_test))
dataset_params = {
'max_bin': 260,
'min_data_in_bin': 5,
'data_random_seed': 123,
}
new_gbm = gbm.refit(
data=X_train,
label=y_train,
weight=np.random.rand(y_train.shape[0]),
dataset_params=dataset_params,
)
TremaMiguel marked this conversation as resolved.
Show resolved Hide resolved
weight_err_pred = log_loss(y_test, new_gbm.predict(X_test))
assert weight_err_pred != non_weight_err_pred

jameslamb marked this conversation as resolved.
Show resolved Hide resolved

def test_mape_rf():
X, y = load_boston(return_X_y=True)
params = {
Expand Down