diff --git a/tests/python_package_test/test_plotting.py b/tests/python_package_test/test_plotting.py index 293012348ac3..2c25d5733ed0 100644 --- a/tests/python_package_test/test_plotting.py +++ b/tests/python_package_test/test_plotting.py @@ -1,9 +1,8 @@ # coding: utf-8 -import unittest - import lightgbm as lgb from lightgbm.compat import MATPLOTLIB_INSTALLED, GRAPHVIZ_INSTALLED from sklearn.model_selection import train_test_split +import pytest if MATPLOTLIB_INSTALLED: import matplotlib @@ -14,164 +13,188 @@ from .utils import load_breast_cancer -class TestBasic(unittest.TestCase): +@pytest.fixture(scope="module") +def breast_cancer_split(): + return train_test_split(*load_breast_cancer(return_X_y=True), + test_size=0.1, random_state=1) + + +@pytest.fixture(scope="module") +def train_data(breast_cancer_split): + X_train, _, y_train, _ = breast_cancer_split + return lgb.Dataset(X_train, y_train) + - def setUp(self): - self.X_train, self.X_test, self.y_train, self.y_test = train_test_split(*load_breast_cancer(return_X_y=True), - test_size=0.1, random_state=1) - self.train_data = lgb.Dataset(self.X_train, self.y_train) - self.params = { - "objective": "binary", +@pytest.fixture +def params(): + return {"objective": "binary", "verbose": -1, - "num_leaves": 3 - } - - @unittest.skipIf(not MATPLOTLIB_INSTALLED, 'matplotlib is not installed') - def test_plot_importance(self): - gbm0 = lgb.train(self.params, self.train_data, num_boost_round=10) - ax0 = lgb.plot_importance(gbm0) - self.assertIsInstance(ax0, matplotlib.axes.Axes) - self.assertEqual(ax0.get_title(), 'Feature importance') - self.assertEqual(ax0.get_xlabel(), 'Feature importance') - self.assertEqual(ax0.get_ylabel(), 'Features') - self.assertLessEqual(len(ax0.patches), 30) - - gbm1 = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, silent=True) - gbm1.fit(self.X_train, self.y_train) - - ax1 = lgb.plot_importance(gbm1, color='r', title='t', xlabel='x', ylabel='y') - self.assertIsInstance(ax1, matplotlib.axes.Axes) - self.assertEqual(ax1.get_title(), 't') - self.assertEqual(ax1.get_xlabel(), 'x') - self.assertEqual(ax1.get_ylabel(), 'y') - self.assertLessEqual(len(ax1.patches), 30) - for patch in ax1.patches: - self.assertTupleEqual(patch.get_facecolor(), (1., 0, 0, 1.)) # red - - ax2 = lgb.plot_importance(gbm0, color=['r', 'y', 'g', 'b'], - title=None, xlabel=None, ylabel=None) - self.assertIsInstance(ax2, matplotlib.axes.Axes) - self.assertEqual(ax2.get_title(), '') - self.assertEqual(ax2.get_xlabel(), '') - self.assertEqual(ax2.get_ylabel(), '') - self.assertLessEqual(len(ax2.patches), 30) - self.assertTupleEqual(ax2.patches[0].get_facecolor(), (1., 0, 0, 1.)) # r - self.assertTupleEqual(ax2.patches[1].get_facecolor(), (.75, .75, 0, 1.)) # y - self.assertTupleEqual(ax2.patches[2].get_facecolor(), (0, .5, 0, 1.)) # g - self.assertTupleEqual(ax2.patches[3].get_facecolor(), (0, 0, 1., 1.)) # b - - @unittest.skipIf(not MATPLOTLIB_INSTALLED, 'matplotlib is not installed') - def test_plot_split_value_histogram(self): - gbm0 = lgb.train(self.params, self.train_data, num_boost_round=10) - ax0 = lgb.plot_split_value_histogram(gbm0, 27) - self.assertIsInstance(ax0, matplotlib.axes.Axes) - self.assertEqual(ax0.get_title(), 'Split value histogram for feature with index 27') - self.assertEqual(ax0.get_xlabel(), 'Feature split value') - self.assertEqual(ax0.get_ylabel(), 'Count') - self.assertLessEqual(len(ax0.patches), 2) - - gbm1 = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, silent=True) - gbm1.fit(self.X_train, self.y_train) - - ax1 = lgb.plot_split_value_histogram(gbm1, gbm1.booster_.feature_name()[27], figsize=(10, 5), - title='Histogram for feature @index/name@ @feature@', - xlabel='x', ylabel='y', color='r') - self.assertIsInstance(ax1, matplotlib.axes.Axes) - self.assertEqual(ax1.get_title(), - 'Histogram for feature name {}'.format(gbm1.booster_.feature_name()[27])) - self.assertEqual(ax1.get_xlabel(), 'x') - self.assertEqual(ax1.get_ylabel(), 'y') - self.assertLessEqual(len(ax1.patches), 2) - for patch in ax1.patches: - self.assertTupleEqual(patch.get_facecolor(), (1., 0, 0, 1.)) # red - - ax2 = lgb.plot_split_value_histogram(gbm0, 27, bins=10, color=['r', 'y', 'g', 'b'], - title=None, xlabel=None, ylabel=None) - self.assertIsInstance(ax2, matplotlib.axes.Axes) - self.assertEqual(ax2.get_title(), '') - self.assertEqual(ax2.get_xlabel(), '') - self.assertEqual(ax2.get_ylabel(), '') - self.assertEqual(len(ax2.patches), 10) - self.assertTupleEqual(ax2.patches[0].get_facecolor(), (1., 0, 0, 1.)) # r - self.assertTupleEqual(ax2.patches[1].get_facecolor(), (.75, .75, 0, 1.)) # y - self.assertTupleEqual(ax2.patches[2].get_facecolor(), (0, .5, 0, 1.)) # g - self.assertTupleEqual(ax2.patches[3].get_facecolor(), (0, 0, 1., 1.)) # b - - self.assertRaises(ValueError, lgb.plot_split_value_histogram, gbm0, 0) # was not used in splitting - - @unittest.skipIf(not MATPLOTLIB_INSTALLED or not GRAPHVIZ_INSTALLED, 'matplotlib or graphviz is not installed') - def test_plot_tree(self): - gbm = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, silent=True) - gbm.fit(self.X_train, self.y_train, verbose=False) - - self.assertRaises(IndexError, lgb.plot_tree, gbm, tree_index=83) - - ax = lgb.plot_tree(gbm, tree_index=3, figsize=(15, 8), show_info=['split_gain']) - self.assertIsInstance(ax, matplotlib.axes.Axes) - w, h = ax.axes.get_figure().get_size_inches() - self.assertEqual(int(w), 15) - self.assertEqual(int(h), 8) - - @unittest.skipIf(not GRAPHVIZ_INSTALLED, 'graphviz is not installed') - def test_create_tree_digraph(self): - constraints = [-1, 1] * int(self.X_train.shape[1] / 2) - gbm = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, silent=True, monotone_constraints=constraints) - gbm.fit(self.X_train, self.y_train, verbose=False) - - self.assertRaises(IndexError, lgb.create_tree_digraph, gbm, tree_index=83) - - graph = lgb.create_tree_digraph(gbm, tree_index=3, - show_info=['split_gain', 'internal_value', 'internal_weight'], - name='Tree4', node_attr={'color': 'red'}) - graph.render(view=False) - self.assertIsInstance(graph, graphviz.Digraph) - self.assertEqual(graph.name, 'Tree4') - self.assertEqual(graph.filename, 'Tree4.gv') - self.assertEqual(len(graph.node_attr), 1) - self.assertEqual(graph.node_attr['color'], 'red') - self.assertEqual(len(graph.graph_attr), 0) - self.assertEqual(len(graph.edge_attr), 0) - graph_body = ''.join(graph.body) - self.assertIn('leaf', graph_body) - self.assertIn('gain', graph_body) - self.assertIn('value', graph_body) - self.assertIn('weight', graph_body) - self.assertIn('#ffdddd', graph_body) - self.assertIn('#ddffdd', graph_body) - self.assertNotIn('data', graph_body) - self.assertNotIn('count', graph_body) - - @unittest.skipIf(not MATPLOTLIB_INSTALLED, 'matplotlib is not installed') - def test_plot_metrics(self): - test_data = lgb.Dataset(self.X_test, self.y_test, reference=self.train_data) - self.params.update({"metric": {"binary_logloss", "binary_error"}}) - - evals_result0 = {} - gbm0 = lgb.train(self.params, self.train_data, - valid_sets=[self.train_data, test_data], - valid_names=['v1', 'v2'], - num_boost_round=10, - evals_result=evals_result0, - verbose_eval=False) - ax0 = lgb.plot_metric(evals_result0) - self.assertIsInstance(ax0, matplotlib.axes.Axes) - self.assertEqual(ax0.get_title(), 'Metric during training') - self.assertEqual(ax0.get_xlabel(), 'Iterations') - self.assertIn(ax0.get_ylabel(), {'binary_logloss', 'binary_error'}) - ax0 = lgb.plot_metric(evals_result0, metric='binary_error') - ax0 = lgb.plot_metric(evals_result0, metric='binary_logloss', dataset_names=['v2']) - - evals_result1 = {} - gbm1 = lgb.train(self.params, self.train_data, - num_boost_round=10, - evals_result=evals_result1, - verbose_eval=False) - self.assertRaises(ValueError, lgb.plot_metric, evals_result1) - - gbm2 = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, silent=True) - gbm2.fit(self.X_train, self.y_train, eval_set=[(self.X_test, self.y_test)], verbose=False) - ax2 = lgb.plot_metric(gbm2, title=None, xlabel=None, ylabel=None) - self.assertIsInstance(ax2, matplotlib.axes.Axes) - self.assertEqual(ax2.get_title(), '') - self.assertEqual(ax2.get_xlabel(), '') - self.assertEqual(ax2.get_ylabel(), '') + "num_leaves": 3} + + +@pytest.mark.skipif(not MATPLOTLIB_INSTALLED, reason='matplotlib is not installed') +def test_plot_importance(params, breast_cancer_split, train_data): + X_train, _, y_train, _ = breast_cancer_split + + gbm0 = lgb.train(params, train_data, num_boost_round=10) + ax0 = lgb.plot_importance(gbm0) + assert isinstance(ax0, matplotlib.axes.Axes) + assert ax0.get_title() == 'Feature importance' + assert ax0.get_xlabel() == 'Feature importance' + assert ax0.get_ylabel() == 'Features' + assert len(ax0.patches) <= 30 + + gbm1 = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, silent=True) + gbm1.fit(X_train, y_train) + + ax1 = lgb.plot_importance(gbm1, color='r', title='t', xlabel='x', ylabel='y') + assert isinstance(ax1, matplotlib.axes.Axes) + assert ax1.get_title() == 't' + assert ax1.get_xlabel() == 'x' + assert ax1.get_ylabel() == 'y' + assert len(ax1.patches) <= 30 + for patch in ax1.patches: + assert patch.get_facecolor() == (1., 0, 0, 1.) # red + + ax2 = lgb.plot_importance(gbm0, color=['r', 'y', 'g', 'b'], + title=None, xlabel=None, ylabel=None) + assert isinstance(ax2, matplotlib.axes.Axes) + assert ax2.get_title() == '' + assert ax2.get_xlabel() == '' + assert ax2.get_ylabel() == '' + assert len(ax2.patches) <= 30 + assert ax2.patches[0].get_facecolor() == (1., 0, 0, 1.) # r + assert ax2.patches[1].get_facecolor() == (.75, .75, 0, 1.) # y + assert ax2.patches[2].get_facecolor() == (0, .5, 0, 1.) # g + assert ax2.patches[3].get_facecolor() == (0, 0, 1., 1.) # b + + +@pytest.mark.skipif(not MATPLOTLIB_INSTALLED, reason='matplotlib is not installed') +def test_plot_split_value_histogram(params, breast_cancer_split, train_data): + X_train, _, y_train, _ = breast_cancer_split + + gbm0 = lgb.train(params, train_data, num_boost_round=10) + ax0 = lgb.plot_split_value_histogram(gbm0, 27) + assert isinstance(ax0, matplotlib.axes.Axes) + assert ax0.get_title() == 'Split value histogram for feature with index 27' + assert ax0.get_xlabel() == 'Feature split value' + assert ax0.get_ylabel() == 'Count' + assert len(ax0.patches) <= 2 + + gbm1 = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, silent=True) + gbm1.fit(X_train, y_train) + + ax1 = lgb.plot_split_value_histogram(gbm1, gbm1.booster_.feature_name()[27], figsize=(10, 5), + title='Histogram for feature @index/name@ @feature@', + xlabel='x', ylabel='y', color='r') + assert isinstance(ax1, matplotlib.axes.Axes) + title = 'Histogram for feature name {}'.format(gbm1.booster_.feature_name()[27]) + assert ax1.get_title() == title + assert ax1.get_xlabel() == 'x' + assert ax1.get_ylabel() == 'y' + assert len(ax1.patches) <= 2 + for patch in ax1.patches: + assert patch.get_facecolor() == (1., 0, 0, 1.) # red + + ax2 = lgb.plot_split_value_histogram(gbm0, 27, bins=10, color=['r', 'y', 'g', 'b'], + title=None, xlabel=None, ylabel=None) + assert isinstance(ax2, matplotlib.axes.Axes) + assert ax2.get_title() == '' + assert ax2.get_xlabel() == '' + assert ax2.get_ylabel() == '' + assert len(ax2.patches) == 10 + assert ax2.patches[0].get_facecolor() == (1., 0, 0, 1.) # r + assert ax2.patches[1].get_facecolor() == (.75, .75, 0, 1.) # y + assert ax2.patches[2].get_facecolor() == (0, .5, 0, 1.) # g + assert ax2.patches[3].get_facecolor() == (0, 0, 1., 1.) # b + + with pytest.raises(ValueError): + lgb.plot_split_value_histogram(gbm0, 0) # was not used in splitting + + +@pytest.mark.skipif(not MATPLOTLIB_INSTALLED or not GRAPHVIZ_INSTALLED, + reason='matplotlib or graphviz is not installed') +def test_plot_tree(breast_cancer_split): + X_train, _, y_train, _ = breast_cancer_split + gbm = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, silent=True) + gbm.fit(X_train, y_train, verbose=False) + + with pytest.raises(IndexError): + lgb.plot_tree(gbm, tree_index=83) + + ax = lgb.plot_tree(gbm, tree_index=3, figsize=(15, 8), show_info=['split_gain']) + assert isinstance(ax, matplotlib.axes.Axes) + w, h = ax.axes.get_figure().get_size_inches() + assert int(w) == 15 + assert int(h) == 8 + + +@pytest.mark.skipif(not GRAPHVIZ_INSTALLED, reason='graphviz is not installed') +def test_create_tree_digraph(breast_cancer_split): + X_train, _, y_train, _ = breast_cancer_split + + constraints = [-1, 1] * int(X_train.shape[1] / 2) + gbm = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, silent=True, monotone_constraints=constraints) + gbm.fit(X_train, y_train, verbose=False) + + with pytest.raises(IndexError): + lgb.create_tree_digraph(gbm, tree_index=83) + + graph = lgb.create_tree_digraph(gbm, tree_index=3, + show_info=['split_gain', 'internal_value', 'internal_weight'], + name='Tree4', node_attr={'color': 'red'}) + graph.render(view=False) + assert isinstance(graph, graphviz.Digraph) + assert graph.name == 'Tree4' + assert graph.filename == 'Tree4.gv' + assert len(graph.node_attr) == 1 + assert graph.node_attr['color'] == 'red' + assert len(graph.graph_attr) == 0 + assert len(graph.edge_attr) == 0 + graph_body = ''.join(graph.body) + assert 'leaf' in graph_body + assert 'gain' in graph_body + assert 'value' in graph_body + assert 'weight' in graph_body + assert '#ffdddd' in graph_body + assert '#ddffdd' in graph_body + assert 'data' not in graph_body + assert 'count' not in graph_body + + +@pytest.mark.skipif(not MATPLOTLIB_INSTALLED, reason='matplotlib is not installed') +def test_plot_metrics(params, breast_cancer_split, train_data): + X_train, X_test, y_train, y_test = breast_cancer_split + test_data = lgb.Dataset(X_test, y_test, reference=train_data) + params.update({"metric": {"binary_logloss", "binary_error"}}) + + evals_result0 = {} + lgb.train(params, train_data, + valid_sets=[train_data, test_data], + valid_names=['v1', 'v2'], + num_boost_round=10, + evals_result=evals_result0, + verbose_eval=False) + ax0 = lgb.plot_metric(evals_result0) + assert isinstance(ax0, matplotlib.axes.Axes) + assert ax0.get_title() == 'Metric during training' + assert ax0.get_xlabel() == 'Iterations' + assert ax0.get_ylabel() in {'binary_logloss', 'binary_error'} + ax0 = lgb.plot_metric(evals_result0, metric='binary_error') + ax0 = lgb.plot_metric(evals_result0, metric='binary_logloss', dataset_names=['v2']) + + evals_result1 = {} + lgb.train(params, train_data, + num_boost_round=10, + evals_result=evals_result1, + verbose_eval=False) + with pytest.raises(ValueError): + lgb.plot_metric(evals_result1) + + gbm2 = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, silent=True) + gbm2.fit(X_train, y_train, eval_set=[(X_test, y_test)], verbose=False) + ax2 = lgb.plot_metric(gbm2, title=None, xlabel=None, ylabel=None) + assert isinstance(ax2, matplotlib.axes.Axes) + assert ax2.get_title() == '' + assert ax2.get_xlabel() == '' + assert ax2.get_ylabel() == ''