From c3b7a233a322803db1f89385f18606b5259179d1 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Wed, 29 Dec 2021 16:24:43 -0600 Subject: [PATCH 01/12] add test for custom objective with regressor --- python-package/lightgbm/dask.py | 8 ++-- tests/python_package_test/test_dask.py | 55 ++++++++++++++++++++++++++ 2 files changed, 59 insertions(+), 4 deletions(-) diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py index 062422286a47..95e636693db1 100644 --- a/python-package/lightgbm/dask.py +++ b/python-package/lightgbm/dask.py @@ -21,7 +21,7 @@ from .compat import (DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED, Client, LGBMNotFittedError, concat, dask_Array, dask_array_from_delayed, dask_bag_from_delayed, dask_DataFrame, dask_Series, default_client, delayed, pd_DataFrame, pd_Series, wait) -from .sklearn import (LGBMClassifier, LGBMModel, LGBMRanker, LGBMRegressor, _LGBM_ScikitCustomEvalFunction, +from .sklearn import (LGBMClassifier, LGBMModel, LGBMRanker, LGBMRegressor, _LGBM_ScikitCustomEvalFunction, _LGBM_ScikitCustomObjectiveFunction, _lgbmmodel_doc_custom_eval_note, _lgbmmodel_doc_fit, _lgbmmodel_doc_predict) _DaskCollection = Union[dask_Array, dask_DataFrame, dask_Series] @@ -1099,7 +1099,7 @@ def __init__( learning_rate: float = 0.1, n_estimators: int = 100, subsample_for_bin: int = 200000, - objective: Optional[str] = None, + objective: Optional[Union[str, _LGBM_ScikitCustomObjectiveFunction]] = None, class_weight: Optional[Union[dict, str]] = None, min_split_gain: float = 0., min_child_weight: float = 1e-3, @@ -1275,7 +1275,7 @@ def __init__( learning_rate: float = 0.1, n_estimators: int = 100, subsample_for_bin: int = 200000, - objective: Optional[str] = None, + objective: Optional[Union[str, _LGBM_ScikitCustomObjectiveFunction]] = None, class_weight: Optional[Union[dict, str]] = None, min_split_gain: float = 0., min_child_weight: float = 1e-3, @@ -1431,7 +1431,7 @@ def __init__( learning_rate: float = 0.1, n_estimators: int = 100, subsample_for_bin: int = 200000, - objective: Optional[str] = None, + objective: Optional[Union[str, _LGBM_ScikitCustomObjectiveFunction]] = None, class_weight: Optional[Union[dict, str]] = None, min_split_gain: float = 0., min_child_weight: float = 1e-3, diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py index b4a948070420..0fffe3d80036 100644 --- a/tests/python_package_test/test_dask.py +++ b/tests/python_package_test/test_dask.py @@ -262,6 +262,12 @@ def _unpickle(filepath, serializer): raise ValueError(f'Unrecognized serializer type: {serializer}') +def _objective_least_squares(y_true, y_pred): + grad = (y_pred - y_true) + hess = np.ones(len(y_true)) + return grad, hess + + @pytest.mark.parametrize('output', data_output) @pytest.mark.parametrize('task', ['binary-classification', 'multiclass-classification']) @pytest.mark.parametrize('boosting_type', boosting_types) @@ -700,6 +706,55 @@ def test_regressor_quantile(output, alpha, cluster): assert tree_df.loc[node_uses_cat_col, "decision_type"].unique()[0] == '==' +@pytest.mark.parametrize('output', data_output) +def test_regressor_custom_objective(output, cluster): + with Client(cluster) as client: + X, y, w, _, dX, dy, dw, _ = _create_data( + objective='regression', + output=output + ) + + params = { + "n_estimators": 10, + "num_leaves": 10, + "objective": _objective_least_squares + } + + dask_regressor = lgb.DaskLGBMRegressor( + client=client, + time_out=5, + tree_learner='data', + **params + ) + dask_regressor = 
dask_regressor.fit(dX, dy, sample_weight=dw) + dask_regressor_local = dask_regressor.to_local() + p1 = dask_regressor.predict(dX) + p1_local = dask_regressor_local.predict(X) + s1_local = dask_regressor_local.score(X, y) + s1 = _r2_score(dy, p1) + p1 = p1.compute() + + local_regressor = lgb.LGBMRegressor(**params) + local_regressor.fit(X, y, sample_weight=w) + p2 = local_regressor.predict(X) + s2 = local_regressor.score(X, y) + + # function should have been preserved + assert callable(dask_regressor.objective) + assert callable(dask_regressor_local.objective) + + # Scores should be the same + assert_eq(s1, s2, atol=0.01) + assert_eq(s1, s1_local) + + # local and Dask predictions should be the same + assert_eq(p1, p1_local) + + # predictions should be better than random + assert_eq(p1, y, rtol=0.5, atol=50.) + assert_eq(p2, y, rtol=0.5, atol=50.) + + @pytest.mark.parametrize('output', ['array', 'dataframe', 'dataframe-with-categorical']) @pytest.mark.parametrize('group', [None, group_sizes]) @pytest.mark.parametrize('boosting_type', boosting_types) From b7b754e867e234d6f2670266bf48c3c7500f7e4b Mon Sep 17 00:00:00 2001 From: James Lamb Date: Wed, 29 Dec 2021 20:35:07 -0600 Subject: [PATCH 02/12] add test for custom binary classification objective with classifier --- tests/python_package_test/test_dask.py | 65 ++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py index 0fffe3d80036..0136fa97a42c 100644 --- a/tests/python_package_test/test_dask.py +++ b/tests/python_package_test/test_dask.py @@ -267,6 +267,12 @@ def _objective_least_squares(y_true, y_pred): hess = np.ones(len(y_true)) return grad, hess +def _objective_logistic_regression(y_true, y_pred): + y_pred = 1.0 / (1.0 + np.exp(-y_pred)) + grad = y_pred - y_true + hess = y_pred * (1.0 - y_pred) + return grad, hess + @pytest.mark.parametrize('output', data_output) @pytest.mark.parametrize('task', ['binary-classification', 'multiclass-classification']) @@ -461,6 +467,65 @@ def test_classifier_pred_contrib(output, task, cluster): assert len(np.unique(preds_with_contrib[:, base_value_col]) == 1) +@pytest.mark.parametrize('output', data_output) +def test_classifier_binary_classification_custom_objective(output, cluster): + with Client(cluster) as client: + X, y, w, _, dX, dy, dw, _ = _create_data( + objective='binary-classification', + output=output + ) + + params = { + "n_estimators": 50, + "num_leaves": 31, + "min_data": 1, + "verbose": -1, + "objective": _objective_logistic_regression + } + + dask_classifier = lgb.DaskLGBMClassifier( + client=client, + time_out=5, + tree_learner='data', + **params + ) + dask_classifier = dask_classifier.fit(dX, dy, sample_weight=dw) + dask_classifier_local = dask_classifier.to_local() + p1 = dask_classifier.predict(dX) + p1_proba = dask_classifier.predict_proba(dX).compute() + p1_local = dask_classifier_local.predict(X) + # with a custom objective, predictiion result is a raw score instead of predicted class + p1_class = (1.0 / (1.0 + np.exp(-p1_proba))) > 0.5 + p1_class = p1_class.astype('int64') + p1_proba_local = dask_classifier_local.predict_proba(X) + p1_class_local = (1.0 / (1.0 + np.exp(-p1_proba_local))) > 0.5 + p1_class_local = p1_class_local.astype('int64') + p1 = p1.compute() + + local_classifier = lgb.LGBMClassifier(**params) + local_classifier.fit(X, y, sample_weight=w) + p2 = local_classifier.predict(X) + p2_proba = local_classifier.predict_proba(X) + p2_class = (1.0 / (1.0 + 
np.exp(-p1_proba))) > 0.5 + p2_class = p2_class.astype('int64') + + # function should have been preserved + assert callable(dask_classifier.objective) + assert callable(dask_classifier_local.objective) + + # should correctly classify every sample + assert_eq(p1, p2) + assert_eq(p1_class, y) + assert_eq(p2_class, y) + + # probability estimates should be similar + assert_eq(p1_proba, p2_proba, atol=0.03) + + # predictions from to_local() model should be identical to those from LGBMClassifier + assert_eq(p1_local, p2) + assert_eq(p1_class_local, y) + + def test_group_workers_by_host(): hosts = [f'0.0.0.{i}' for i in range(2)] workers = [f'tcp://{host}:{p}' for p in range(2) for host in hosts] From 82d1ed48b96f22f2da6231d1e10234fee3d27336 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Wed, 29 Dec 2021 20:50:55 -0600 Subject: [PATCH 03/12] isort --- python-package/lightgbm/dask.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py index 95e636693db1..483b82358d13 100644 --- a/python-package/lightgbm/dask.py +++ b/python-package/lightgbm/dask.py @@ -21,8 +21,9 @@ from .compat import (DASK_INSTALLED, PANDAS_INSTALLED, SKLEARN_INSTALLED, Client, LGBMNotFittedError, concat, dask_Array, dask_array_from_delayed, dask_bag_from_delayed, dask_DataFrame, dask_Series, default_client, delayed, pd_DataFrame, pd_Series, wait) -from .sklearn import (LGBMClassifier, LGBMModel, LGBMRanker, LGBMRegressor, _LGBM_ScikitCustomEvalFunction, _LGBM_ScikitCustomObjectiveFunction, - _lgbmmodel_doc_custom_eval_note, _lgbmmodel_doc_fit, _lgbmmodel_doc_predict) +from .sklearn import (LGBMClassifier, LGBMModel, LGBMRanker, LGBMRegressor, _LGBM_ScikitCustomEvalFunction, + _LGBM_ScikitCustomObjectiveFunction, _lgbmmodel_doc_custom_eval_note, _lgbmmodel_doc_fit, + _lgbmmodel_doc_predict) _DaskCollection = Union[dask_Array, dask_DataFrame, dask_Series] _DaskMatrixLike = Union[dask_Array, dask_DataFrame] From 0a029a7ee05fb7690b1fbe6df06994aedeb7ab2e Mon Sep 17 00:00:00 2001 From: James Lamb Date: Thu, 30 Dec 2021 01:27:20 -0600 Subject: [PATCH 04/12] got tests working for multiclass --- tests/python_package_test/test_dask.py | 120 +++++++++++++++++++++---- 1 file changed, 105 insertions(+), 15 deletions(-) diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py index 0136fa97a42c..1339fe772421 100644 --- a/tests/python_package_test/test_dask.py +++ b/tests/python_package_test/test_dask.py @@ -267,6 +267,7 @@ def _objective_least_squares(y_true, y_pred): hess = np.ones(len(y_true)) return grad, hess + def _objective_logistic_regression(y_true, y_pred): y_pred = 1.0 / (1.0 + np.exp(-y_pred)) grad = y_pred - y_true @@ -274,6 +275,24 @@ def _objective_logistic_regression(y_true, y_pred): return grad, hess +def _objective_logloss(y_true, y_pred): + num_rows = len(y_true) + num_class = len(np.unique(y_true)) + # operate on preds as [num_data, num_classes] matrix + y_pred = y_pred.T.reshape(-1, num_class) + row_wise_max = np.max(y_pred, axis=1).reshape(num_rows, 1) + preds = y_pred - row_wise_max + prob = np.exp(preds) / np.sum(np.exp(preds), axis=1).reshape(num_rows, 1) + grad_update = np.zeros_like(preds) + grad_update[np.arange(num_rows), y_true.astype('int')] = -1.0 + grad = prob + grad_update + hess = 2.0 * prob * (1.0 - prob) + # reshape back to 1-D array, grouped by class id and then row id + grad = grad.T.reshape(-1) + hess = hess.T.reshape(-1) + return grad, hess + + @pytest.mark.parametrize('output', 
data_output) @pytest.mark.parametrize('task', ['binary-classification', 'multiclass-classification']) @pytest.mark.parametrize('boosting_type', boosting_types) @@ -468,21 +487,32 @@ def test_classifier_pred_contrib(output, task, cluster): @pytest.mark.parametrize('output', data_output) -def test_classifier_binary_classification_custom_objective(output, cluster): +@pytest.mark.parametrize('task', ['binary-classification', 'multiclass-classification']) +def test_classifier_custom_objective(output, task, cluster): with Client(cluster) as client: X, y, w, _, dX, dy, dw, _ = _create_data( - objective='binary-classification', + objective=task, output=output ) params = { - "n_estimators": 50, - "num_leaves": 31, + "n_estimators": 10, + "num_leaves": 10, "min_data": 1, "verbose": -1, - "objective": _objective_logistic_regression + "learning_rate": 0.01, } + if task == 'binary-classification': + params.update({ + 'objective': _objective_logistic_regression, + }) + elif task == 'multiclass-classification': + params.update({ + 'objective': _objective_logloss, + 'num_classes': 3 + }) + dask_classifier = lgb.DaskLGBMClassifier( client=client, time_out=5, @@ -491,40 +521,38 @@ def test_classifier_binary_classification_custom_objective(output, cluster): ) dask_classifier = dask_classifier.fit(dX, dy, sample_weight=dw) dask_classifier_local = dask_classifier.to_local() - p1 = dask_classifier.predict(dX) p1_proba = dask_classifier.predict_proba(dX).compute() - p1_local = dask_classifier_local.predict(X) + p1_proba_local = dask_classifier_local.predict_proba(X) + # with a custom objective, predictiion result is a raw score instead of predicted class p1_class = (1.0 / (1.0 + np.exp(-p1_proba))) > 0.5 p1_class = p1_class.astype('int64') - p1_proba_local = dask_classifier_local.predict_proba(X) p1_class_local = (1.0 / (1.0 + np.exp(-p1_proba_local))) > 0.5 p1_class_local = p1_class_local.astype('int64') - p1 = p1.compute() local_classifier = lgb.LGBMClassifier(**params) local_classifier.fit(X, y, sample_weight=w) - p2 = local_classifier.predict(X) p2_proba = local_classifier.predict_proba(X) p2_class = (1.0 / (1.0 + np.exp(-p1_proba))) > 0.5 p2_class = p2_class.astype('int64') + if task == 'multiclass-classification': + p1_class = p1_class.argmax(axis=1) + p1_class_local = p1_class_local.argmax(axis=1) + p2_class = p2_class.argmax(axis=1) + # function should have been preserved assert callable(dask_classifier.objective) assert callable(dask_classifier_local.objective) # should correctly classify every sample - assert_eq(p1, p2) assert_eq(p1_class, y) + assert_eq(p1_class_local, y) assert_eq(p2_class, y) # probability estimates should be similar assert_eq(p1_proba, p2_proba, atol=0.03) - # predictions from to_local() model should be identical to those from LGBMClassifier - assert_eq(p1_local, p2) - assert_eq(p1_class_local, y) - def test_group_workers_by_host(): hosts = [f'0.0.0.{i}' for i in range(2)] @@ -928,6 +956,68 @@ def test_ranker(output, group, boosting_type, tree_learner, cluster): assert tree_df.loc[node_uses_cat_col, "decision_type"].unique()[0] == '==' +@pytest.mark.parametrize('output', ['array', 'dataframe', 'dataframe-with-categorical']) +def test_ranker_custom_objective(output, cluster): + with Client(cluster) as client: + if output == 'dataframe-with-categorical': + X, y, w, g, dX, dy, dw, dg = _create_data( + objective='ranking', + output=output, + group=group_sizes, + n_features=1, + n_informative=1 + ) + else: + X, y, w, g, dX, dy, dw, dg = _create_data( + objective='ranking', + 
output=output, + group=group_sizes + ) + + # rebalance small dask.Array dataset for better performance. + if output == 'array': + dX = dX.persist() + dy = dy.persist() + dw = dw.persist() + dg = dg.persist() + _ = wait([dX, dy, dw, dg]) + client.rebalance() + + params = { + "random_state": 42, + "n_estimators": 50, + "num_leaves": 20, + "min_child_samples": 1, + "objective": _objective_least_squares + } + + dask_ranker = lgb.DaskLGBMRanker( + client=client, + time_out=5, + tree_learner_type="data", + **params + ) + dask_ranker = dask_ranker.fit(dX, dy, sample_weight=dw, group=dg) + rnkvec_dask = dask_ranker.predict(dX).compute() + dask_ranker_local = dask_ranker.to_local() + rnkvec_dask_local = dask_ranker_local.predict(X) + + local_ranker = lgb.LGBMRanker(**params) + local_ranker.fit(X, y, sample_weight=w, group=g) + rnkvec_local = local_ranker.predict(X) + + # distributed ranker should be able to rank decently well with the least-squares objective + # and should have high rank correlation with scores from serial ranker. + dcor = spearmanr(rnkvec_dask, y).correlation + assert dcor > 0.6 + assert spearmanr(rnkvec_dask, rnkvec_local).correlation > 0.8 + assert_eq(rnkvec_dask, rnkvec_dask_local) + + # function should have been preserved + assert callable(dask_ranker.objective) + assert callable(dask_ranker_local.objective) + + @pytest.mark.parametrize('task', tasks) @pytest.mark.parametrize('output', data_output) @pytest.mark.parametrize('eval_sizes', [[0.5, 1, 1.5], [0]]) From a1627a507da31b7b2c76d6d39e058863ac832c50 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Thu, 30 Dec 2021 01:53:20 -0600 Subject: [PATCH 05/12] update docs --- docs/Parallel-Learning-Guide.rst | 35 ++++++++++++++++++++++++++ python-package/lightgbm/dask.py | 17 +++---------- tests/python_package_test/test_dask.py | 2 +- 3 files changed, 39 insertions(+), 15 deletions(-) diff --git a/docs/Parallel-Learning-Guide.rst b/docs/Parallel-Learning-Guide.rst index 2fe895d5d3a2..f220ebe7c28c 100644 --- a/docs/Parallel-Learning-Guide.rst +++ b/docs/Parallel-Learning-Guide.rst @@ -230,6 +230,41 @@ You could edit your firewall rules to allow communication between any of the wor * the port ``local_listen_port`` is not open on any of the worker hosts * any machine has multiple Dask worker processes running on it +Using Custom Objective Functions with Dask +****************************************** + +It is possible to customize the boosting process by providing a custom objective function written in Python. +See the Dask API's documentation for details on how to implement such functions. + +.. warning:: + + Custom objective functions used with ``lightgbm.dask`` will be called by each worker process on only that worker's local data. + +Follow the example below to use a custom implementation of the ``regression_l2`` objective. + +.. 
code:: python + + import dask.array as da + import lightgbm as lgb + import numpy as np + from distributed import Client, LocalCluster + + cluster = LocalCluster(n_workers=2) + client = Client(cluster) + + X = da.random.random((1000, 10), (500, 10)) + y = da.random.random((1000,), (500,)) + + def custom_l2_obj(y_true, y_pred): + grad = y_pred - y_true + hess = np.ones(len(y_true)) + return grad, hess + + dask_model = lgb.DaskLGBMRegressor( + objective=custom_l2_obj + ) + dask_model.fit(X, y) + Prediction with Dask '''''''''''''''''''' diff --git a/python-package/lightgbm/dask.py b/python-package/lightgbm/dask.py index 483b82358d13..8acab0a8ae7b 100644 --- a/python-package/lightgbm/dask.py +++ b/python-package/lightgbm/dask.py @@ -1143,16 +1143,12 @@ def __init__( _base_doc = LGBMClassifier.__init__.__doc__ _before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs') # type: ignore - _base_doc = f""" + __init__.__doc__ = f""" {_before_kwargs}client : dask.distributed.Client or None, optional (default=None) {' ':4}Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. The Dask client used by this class will not be saved if the model object is pickled. {_kwargs}{_after_kwargs} """ - # the note on custom objective functions in LGBMModel.__init__ is not - # currently relevant for the Dask estimators - __init__.__doc__ = _base_doc[:_base_doc.find('Note\n')] - def __getstate__(self) -> Dict[Any, Any]: return self._lgb_dask_getstate() @@ -1319,14 +1315,11 @@ def __init__( _base_doc = LGBMRegressor.__init__.__doc__ _before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs') # type: ignore - _base_doc = f""" + __init__.__doc__ = f""" {_before_kwargs}client : dask.distributed.Client or None, optional (default=None) {' ':4}Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. The Dask client used by this class will not be saved if the model object is pickled. {_kwargs}{_after_kwargs} """ - # the note on custom objective functions in LGBMModel.__init__ is not - # currently relevant for the Dask estimators - __init__.__doc__ = _base_doc[:_base_doc.find('Note\n')] def __getstate__(self) -> Dict[Any, Any]: return self._lgb_dask_getstate() @@ -1475,16 +1468,12 @@ def __init__( _base_doc = LGBMRanker.__init__.__doc__ _before_kwargs, _kwargs, _after_kwargs = _base_doc.partition('**kwargs') # type: ignore - _base_doc = f""" + __init__.__doc__ = f""" {_before_kwargs}client : dask.distributed.Client or None, optional (default=None) {' ':4}Dask client. If ``None``, ``distributed.default_client()`` will be used at runtime. The Dask client used by this class will not be saved if the model object is pickled. 
{_kwargs}{_after_kwargs} """ - # the note on custom objective functions in LGBMModel.__init__ is not - # currently relevant for the Dask estimators - __init__.__doc__ = _base_doc[:_base_doc.find('Note\n')] - def __getstate__(self) -> Dict[Any, Any]: return self._lgb_dask_getstate() diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py index 1339fe772421..02691d7d4e29 100644 --- a/tests/python_package_test/test_dask.py +++ b/tests/python_package_test/test_dask.py @@ -263,7 +263,7 @@ def _unpickle(filepath, serializer): def _objective_least_squares(y_true, y_pred): - grad = (y_pred - y_true) + grad = y_pred - y_true hess = np.ones(len(y_true)) return grad, hess From 8b784190e7569be242f516edd9d3a71e480e5f44 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Thu, 30 Dec 2021 14:55:24 -0600 Subject: [PATCH 06/12] train deeper model for classifier --- tests/python_package_test/test_dask.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py index 02691d7d4e29..add178ab81bb 100644 --- a/tests/python_package_test/test_dask.py +++ b/tests/python_package_test/test_dask.py @@ -496,8 +496,8 @@ def test_classifier_custom_objective(output, task, cluster): ) params = { - "n_estimators": 10, - "num_leaves": 10, + "n_estimators": 50, + "num_leaves": 31, "min_data": 1, "verbose": -1, "learning_rate": 0.01, From 80bd399bff54e6c856404979a4a128785d0b30df Mon Sep 17 00:00:00 2001 From: James Lamb Date: Thu, 30 Dec 2021 20:17:12 -0600 Subject: [PATCH 07/12] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: José Morales --- tests/python_package_test/test_dask.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py index add178ab81bb..f6c5025b0787 100644 --- a/tests/python_package_test/test_dask.py +++ b/tests/python_package_test/test_dask.py @@ -279,14 +279,15 @@ def _objective_logloss(y_true, y_pred): num_rows = len(y_true) num_class = len(np.unique(y_true)) # operate on preds as [num_data, num_classes] matrix - y_pred = y_pred.T.reshape(-1, num_class) + y_pred = y_pred.reshape(-1, num_class, order='F') row_wise_max = np.max(y_pred, axis=1).reshape(num_rows, 1) preds = y_pred - row_wise_max prob = np.exp(preds) / np.sum(np.exp(preds), axis=1).reshape(num_rows, 1) grad_update = np.zeros_like(preds) grad_update[np.arange(num_rows), y_true.astype('int')] = -1.0 grad = prob + grad_update - hess = 2.0 * prob * (1.0 - prob) + factor = num_class / (num_class - 1) + hess = factor * prob * (1 - prob) # reshape back to 1-D array, grouped by class id and then row id grad = grad.T.reshape(-1) hess = hess.T.reshape(-1) From 98f4132de6f92e3b4198eb38c1471475df10c5c1 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Tue, 4 Jan 2022 18:56:30 -0600 Subject: [PATCH 08/12] Apply suggestions from code review Co-authored-by: Nikita Titov --- tests/python_package_test/test_dask.py | 22 +++++++++++----------- 1 file changed, 11 insertions(+), 11 deletions(-) diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py index f6c5025b0787..8910a5a3b210 100644 --- a/tests/python_package_test/test_dask.py +++ b/tests/python_package_test/test_dask.py @@ -284,7 +284,7 @@ def _objective_logloss(y_true, y_pred): preds = y_pred - row_wise_max prob = np.exp(preds) / np.sum(np.exp(preds), 
axis=1).reshape(num_rows, 1) grad_update = np.zeros_like(preds) - grad_update[np.arange(num_rows), y_true.astype('int')] = -1.0 + grad_update[np.arange(num_rows), y_true.astype(np.int32)] = -1.0 grad = prob + grad_update factor = num_class / (num_class - 1) hess = factor * prob * (1 - prob) @@ -525,17 +525,17 @@ def test_classifier_custom_objective(output, task, cluster): p1_proba = dask_classifier.predict_proba(dX).compute() p1_proba_local = dask_classifier_local.predict_proba(X) - # with a custom objective, predictiion result is a raw score instead of predicted class + # with a custom objective, prediction result is a raw score instead of predicted class p1_class = (1.0 / (1.0 + np.exp(-p1_proba))) > 0.5 - p1_class = p1_class.astype('int64') + p1_class = p1_class.astype(np.int64) p1_class_local = (1.0 / (1.0 + np.exp(-p1_proba_local))) > 0.5 - p1_class_local = p1_class_local.astype('int64') + p1_class_local = p1_class_local.astype(np.int64) local_classifier = lgb.LGBMClassifier(**params) local_classifier.fit(X, y, sample_weight=w) p2_proba = local_classifier.predict_proba(X) p2_class = (1.0 / (1.0 + np.exp(-p1_proba))) > 0.5 - p2_class = p2_class.astype('int64') + p2_class = p2_class.astype(np.int64) if task == 'multiclass-classification': p1_class = p1_class.argmax(axis=1) @@ -543,8 +543,8 @@ def test_classifier_custom_objective(output, task, cluster): p2_class = p2_class.argmax(axis=1) # function should have been preserved - assert callable(dask_classifier.objective) - assert callable(dask_classifier_local.objective) + assert callable(dask_classifier.objective_) + assert callable(dask_classifier_local.objective_) # should correctly classify every sample assert_eq(p1_class, y) @@ -834,8 +834,8 @@ def test_regressor_custom_objective(output, cluster): s2 = local_regressor.score(X, y) # function should have been preserved - assert callable(dask_regressor.objective) - assert callable(dask_regressor_local.objective) + assert callable(dask_regressor.objective_) + assert callable(dask_regressor_local.objective_) # Scores should be the same assert_eq(s1, s2, atol=0.01) @@ -1015,8 +1015,8 @@ def test_ranker_custom_objective(output, cluster): assert_eq(rnkvec_dask, rnkvec_dask_local) # function should have been preserved - assert callable(dask_ranker.objective) - assert callable(dask_ranker_local.objective) + assert callable(dask_ranker.objective_) + assert callable(dask_ranker_local.objective_) @pytest.mark.parametrize('task', tasks) From 54df09e734cbc98f261bf4a108e781e41d591914 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Wed, 12 Jan 2022 21:16:09 -0600 Subject: [PATCH 09/12] update multiclass tests --- tests/python_package_test/test_dask.py | 36 ++++++++++++++------------ 1 file changed, 19 insertions(+), 17 deletions(-) diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py index 8910a5a3b210..3ff84f6329ac 100644 --- a/tests/python_package_test/test_dask.py +++ b/tests/python_package_test/test_dask.py @@ -493,15 +493,16 @@ def test_classifier_custom_objective(output, task, cluster): with Client(cluster) as client: X, y, w, _, dX, dy, dw, _ = _create_data( objective=task, - output=output + output=output, ) params = { "n_estimators": 50, "num_leaves": 31, - "min_data": 1, "verbose": -1, - "learning_rate": 0.01, + "seed": 708, + "deterministic": True, + "force_col_wise": True } if task == 'binary-classification': @@ -522,25 +523,26 @@ def test_classifier_custom_objective(output, task, cluster): ) dask_classifier = dask_classifier.fit(dX, dy, 
sample_weight=dw) dask_classifier_local = dask_classifier.to_local() - p1_proba = dask_classifier.predict_proba(dX).compute() - p1_proba_local = dask_classifier_local.predict_proba(X) + p1_raw = dask_classifier.predict(dX, raw_score=True).compute() + p1_raw_local = dask_classifier_local.predict(X, raw_score=True) # with a custom objective, prediction result is a raw score instead of predicted class - p1_class = (1.0 / (1.0 + np.exp(-p1_proba))) > 0.5 - p1_class = p1_class.astype(np.int64) - p1_class_local = (1.0 / (1.0 + np.exp(-p1_proba_local))) > 0.5 - p1_class_local = p1_class_local.astype(np.int64) + p1_proba = 1.0 / (1.0 + np.exp(-p1_raw)) + p1_proba_local = 1.0 / (1.0 + np.exp(-p1_raw_local)) local_classifier = lgb.LGBMClassifier(**params) local_classifier.fit(X, y, sample_weight=w) - p2_proba = local_classifier.predict_proba(X) - p2_class = (1.0 / (1.0 + np.exp(-p1_proba))) > 0.5 - p2_class = p2_class.astype(np.int64) + p2_raw = local_classifier.predict(X, raw_score=True) + p2_proba = 1.0 / (1.0 + np.exp(-p2_raw)) - if task == 'multiclass-classification': - p1_class = p1_class.argmax(axis=1) - p1_class_local = p1_class_local.argmax(axis=1) - p2_class = p2_class.argmax(axis=1) + if task == 'binary-classification': + p1_class = (p1_proba > 0.5).astype(np.int64) + p1_class_local = (p1_proba_local > 0.5).astype(np.int64) + p2_class = (p2_proba > 0.5).astype(np.int64) + elif task == 'multiclass-classification': + p1_class = p1_proba.argmax(axis=1) + p1_class_local = p1_proba_local.argmax(axis=1) + p2_class = p2_proba.argmax(axis=1) # function should have been preserved assert callable(dask_classifier.objective_) @@ -552,7 +554,7 @@ def test_classifier_custom_objective(output, task, cluster): assert_eq(p2_class, y) # probability estimates should be similar - assert_eq(p1_proba, p2_proba, atol=0.03) + assert_eq(p1_proba, p2_proba, atol=0.04) def test_group_workers_by_host(): From f05c9b6d3c5ac55f8cc214716a602d282ffcd664 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Fri, 14 Jan 2022 11:55:49 -0600 Subject: [PATCH 10/12] Apply suggestions from code review Co-authored-by: Nikita Titov --- tests/python_package_test/test_dask.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py index 3ff84f6329ac..181d06459742 100644 --- a/tests/python_package_test/test_dask.py +++ b/tests/python_package_test/test_dask.py @@ -555,6 +555,7 @@ def test_classifier_custom_objective(output, task, cluster): # probability estimates should be similar assert_eq(p1_proba, p2_proba, atol=0.04) + assert_eq(p1_proba, p1_proba_local) def test_group_workers_by_host(): @@ -847,8 +848,9 @@ def test_regressor_custom_objective(output, cluster): assert_eq(p1, p1_local) # predictions should be better than random - assert_eq(p1, y, rtol=0.5, atol=50.) - assert_eq(p2, y, rtol=0.5, atol=50.) + assert_precision = {"rtol": 0.5, "atol": 50.} + assert_eq(p1, y, **assert_precision) + assert_eq(p2, y, **assert_precision) @pytest.mark.parametrize('output', ['array', 'dataframe', 'dataframe-with-categorical']) @@ -1011,8 +1013,7 @@ def test_ranker_custom_objective(output, cluster): # distributed ranker should be able to rank decently well with the least-squares objective # and should have high rank correlation with scores from serial ranker. 
- dcor = spearmanr(rnkvec_dask, y).correlation - assert dcor > 0.6 + assert spearmanr(rnkvec_dask, y).correlation > 0.6 assert spearmanr(rnkvec_dask, rnkvec_local).correlation > 0.8 assert_eq(rnkvec_dask, rnkvec_dask_local) From 2e2752cbf4975e71f9945c8e8ac18047dd234abb Mon Sep 17 00:00:00 2001 From: James Lamb Date: Fri, 14 Jan 2022 19:55:05 -0600 Subject: [PATCH 11/12] fix multiclass probabilities --- tests/python_package_test/test_dask.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py index 181d06459742..5cefd25fbbd1 100644 --- a/tests/python_package_test/test_dask.py +++ b/tests/python_package_test/test_dask.py @@ -526,22 +526,24 @@ def test_classifier_custom_objective(output, task, cluster): p1_raw = dask_classifier.predict(dX, raw_score=True).compute() p1_raw_local = dask_classifier_local.predict(X, raw_score=True) - # with a custom objective, prediction result is a raw score instead of predicted class - p1_proba = 1.0 / (1.0 + np.exp(-p1_raw)) - p1_proba_local = 1.0 / (1.0 + np.exp(-p1_raw_local)) - local_classifier = lgb.LGBMClassifier(**params) local_classifier.fit(X, y, sample_weight=w) p2_raw = local_classifier.predict(X, raw_score=True) - p2_proba = 1.0 / (1.0 + np.exp(-p2_raw)) + # with a custom objective, prediction result is a raw score instead of predicted class if task == 'binary-classification': + p1_proba = 1.0 / (1.0 + np.exp(-p1_raw)) p1_class = (p1_proba > 0.5).astype(np.int64) + p1_proba_local = 1.0 / (1.0 + np.exp(-p1_raw_local)) p1_class_local = (p1_proba_local > 0.5).astype(np.int64) + p2_proba = 1.0 / (1.0 + np.exp(-p2_raw)) p2_class = (p2_proba > 0.5).astype(np.int64) elif task == 'multiclass-classification': + p1_proba = np.exp(p1_raw)/np.sum(np.exp(p1_raw), axis=1).reshape(-1, 1) p1_class = p1_proba.argmax(axis=1) + p1_proba_local = np.exp(p1_raw_local)/np.sum(np.exp(p1_raw_local), axis=1).reshape(-1, 1) p1_class_local = p1_proba_local.argmax(axis=1) + p2_proba = np.exp(p2_raw)/np.sum(np.exp(p2_raw), axis=1).reshape(-1, 1) p2_class = p2_proba.argmax(axis=1) # function should have been preserved @@ -554,7 +556,7 @@ def test_classifier_custom_objective(output, task, cluster): assert_eq(p2_class, y) # probability estimates should be similar - assert_eq(p1_proba, p2_proba, atol=0.04) + assert_eq(p1_proba, p2_proba, atol=0.03) assert_eq(p1_proba, p1_proba_local) From c55ed6cac3dd2e34937846e96dd04ae68bb1e8a0 Mon Sep 17 00:00:00 2001 From: James Lamb Date: Fri, 14 Jan 2022 19:59:09 -0600 Subject: [PATCH 12/12] linting --- tests/python_package_test/test_dask.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/python_package_test/test_dask.py b/tests/python_package_test/test_dask.py index 5cefd25fbbd1..5812e5761134 100644 --- a/tests/python_package_test/test_dask.py +++ b/tests/python_package_test/test_dask.py @@ -539,11 +539,11 @@ def test_classifier_custom_objective(output, task, cluster): p2_proba = 1.0 / (1.0 + np.exp(-p2_raw)) p2_class = (p2_proba > 0.5).astype(np.int64) elif task == 'multiclass-classification': - p1_proba = np.exp(p1_raw)/np.sum(np.exp(p1_raw), axis=1).reshape(-1, 1) + p1_proba = np.exp(p1_raw) / np.sum(np.exp(p1_raw), axis=1).reshape(-1, 1) p1_class = p1_proba.argmax(axis=1) - p1_proba_local = np.exp(p1_raw_local)/np.sum(np.exp(p1_raw_local), axis=1).reshape(-1, 1) + p1_proba_local = np.exp(p1_raw_local) / np.sum(np.exp(p1_raw_local), axis=1).reshape(-1, 1) p1_class_local = p1_proba_local.argmax(axis=1) - 
p2_proba = np.exp(p2_raw)/np.sum(np.exp(p2_raw), axis=1).reshape(-1, 1) + p2_proba = np.exp(p2_raw) / np.sum(np.exp(p2_raw), axis=1).reshape(-1, 1) p2_class = p2_proba.argmax(axis=1) # function should have been preserved
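

Beyond the ``regression_l2`` example added to ``Parallel-Learning-Guide.rst`` above, these patches also exercise a custom multiclass objective in the tests. The following is a minimal end-to-end sketch of that multiclass case, assuming a two-worker ``LocalCluster``: the synthetic blob data, the ``softmax_objective`` helper (which mirrors the ``_objective_logloss`` test helper above), and all parameter values are illustrative assumptions, not part of LightGBM's API. As in the tests, ``num_classes`` is passed explicitly alongside the callable objective, and the output of ``predict()`` is treated as raw scores.

.. code:: python

    import dask.array as da
    import lightgbm as lgb
    import numpy as np
    from distributed import Client, LocalCluster

    cluster = LocalCluster(n_workers=2)
    client = Client(cluster)

    # synthetic data for illustration: three well-separated blobs, one per class,
    # shuffled so that every chunk (and therefore every worker) sees all classes
    rng = np.random.default_rng(42)
    centers = np.array([[-4.0, -4.0], [0.0, 0.0], [4.0, 4.0]])
    X_np = np.concatenate([rng.normal(loc=c, scale=0.5, size=(500, 2)) for c in centers])
    y_np = np.repeat([0, 1, 2], 500)
    order = rng.permutation(len(y_np))
    X_np, y_np = X_np[order], y_np[order]
    X = da.from_array(X_np, chunks=(500, 2))
    y = da.from_array(y_np, chunks=(500,))

    def softmax_objective(y_true, y_pred):
        # same logic as the _objective_logloss helper in the tests,
        # with the number of classes fixed up front for this sketch
        num_rows = len(y_true)
        num_class = 3
        # LightGBM hands scores over as a 1-D array grouped by class,
        # so view them as a [num_data, num_classes] matrix first
        y_pred = y_pred.reshape(-1, num_class, order='F')
        # numerically stable softmax over the class axis
        row_wise_max = np.max(y_pred, axis=1).reshape(num_rows, 1)
        preds = y_pred - row_wise_max
        prob = np.exp(preds) / np.sum(np.exp(preds), axis=1).reshape(num_rows, 1)
        # gradient of the softmax cross-entropy: prob - one_hot(y_true)
        grad = prob.copy()
        grad[np.arange(num_rows), y_true.astype(np.int32)] -= 1.0
        factor = num_class / (num_class - 1)
        hess = factor * prob * (1.0 - prob)
        # flatten back to 1-D, grouped by class id and then row id
        return grad.T.reshape(-1), hess.T.reshape(-1)

    dask_model = lgb.DaskLGBMClassifier(
        objective=softmax_objective,
        num_classes=3,
        n_estimators=10
    )
    dask_model.fit(X, y)

    # with a custom objective, predict() returns raw scores, so apply a
    # softmax over the class axis to recover probabilities and classes
    raw_scores = dask_model.predict(X, raw_score=True).compute()
    proba = np.exp(raw_scores) / np.sum(np.exp(raw_scores), axis=1).reshape(-1, 1)
    pred_class = proba.argmax(axis=1)

For the binary objective used in these tests, the analogous conversion from raw scores is a sigmoid, ``1.0 / (1.0 + np.exp(-raw_scores))``, followed by thresholding at 0.5, exactly as ``test_classifier_custom_objective`` does in its final form.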