Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[python-package] migrate test_plotting.py to pytest #3811

Merged
merged 2 commits into from
Jan 22, 2021
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
345 changes: 184 additions & 161 deletions tests/python_package_test/test_plotting.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,8 @@
# coding: utf-8
import unittest

import lightgbm as lgb
from lightgbm.compat import MATPLOTLIB_INSTALLED, GRAPHVIZ_INSTALLED
from sklearn.model_selection import train_test_split
import pytest

if MATPLOTLIB_INSTALLED:
import matplotlib
Expand All @@ -14,164 +13,188 @@
from .utils import load_breast_cancer


class TestBasic(unittest.TestCase):
@pytest.fixture(scope="module")
def breast_cancer_split():
    """Module-scoped fixture holding a train/test split of the breast cancer data.

    The ``(X, y)`` pair returned by ``load_breast_cancer(return_X_y=True)`` is
    passed to ``train_test_split`` with ``test_size=0.1`` and
    ``random_state=1``. Tests in this module unpack the result as
    ``X_train, X_test, y_train, y_test``.
    """
    features_and_target = load_breast_cancer(return_X_y=True)
    return train_test_split(*features_and_target, test_size=0.1, random_state=1)


@pytest.fixture(scope="module")
def train_data(breast_cancer_split):
    """Module-scoped fixture: an ``lgb.Dataset`` built from the training part of the split."""
    features, _, target, _ = breast_cancer_split
    return lgb.Dataset(data=features, label=target)


# NOTE(review): this span appears to interleave two definitions from the diff
# view -- the unittest-style ``setUp`` (which stores the split, the Dataset and
# the params dict on ``self``) and the pytest ``params`` fixture (which returns
# the same three settings as a dict). Indentation was lost in extraction and the
# fixture's closing line sits much further down, so this is not runnable as
# written -- confirm against the repository before editing the code itself.
def setUp(self):
self.X_train, self.X_test, self.y_train, self.y_test = train_test_split(*load_breast_cancer(return_X_y=True),
test_size=0.1, random_state=1)
self.train_data = lgb.Dataset(self.X_train, self.y_train)
self.params = {
"objective": "binary",
@pytest.fixture
def params():
return {"objective": "binary",
"verbose": -1,
"num_leaves": 3
}

@unittest.skipIf(not MATPLOTLIB_INSTALLED, 'matplotlib is not installed')
def test_plot_importance(self):
    """Check the axes returned by ``lgb.plot_importance``.

    Covers three calls: defaults on a trained Booster, custom labels and a
    single color on an ``LGBMClassifier``, and blank labels with a list of
    colors on the Booster.
    """
    booster = lgb.train(self.params, self.train_data, num_boost_round=10)

    # Defaults: built-in title and axis labels.
    default_axes = lgb.plot_importance(booster)
    self.assertIsInstance(default_axes, matplotlib.axes.Axes)
    self.assertEqual(default_axes.get_title(), 'Feature importance')
    self.assertEqual(default_axes.get_xlabel(), 'Feature importance')
    self.assertEqual(default_axes.get_ylabel(), 'Features')
    self.assertLessEqual(len(default_axes.patches), 30)

    classifier = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, silent=True)
    classifier.fit(self.X_train, self.y_train)

    # Custom labels and one color: every bar is expected to be red.
    red_rgba = (1., 0, 0, 1.)
    labelled_axes = lgb.plot_importance(classifier, color='r', title='t', xlabel='x', ylabel='y')
    self.assertIsInstance(labelled_axes, matplotlib.axes.Axes)
    self.assertEqual(labelled_axes.get_title(), 't')
    self.assertEqual(labelled_axes.get_xlabel(), 'x')
    self.assertEqual(labelled_axes.get_ylabel(), 'y')
    self.assertLessEqual(len(labelled_axes.patches), 30)
    for bar in labelled_axes.patches:
        self.assertTupleEqual(bar.get_facecolor(), red_rgba)

    # ``None`` labels come back as empty strings; the first four bars follow
    # the order of the color list.
    blank_axes = lgb.plot_importance(booster, color=['r', 'y', 'g', 'b'],
                                     title=None, xlabel=None, ylabel=None)
    self.assertIsInstance(blank_axes, matplotlib.axes.Axes)
    self.assertEqual(blank_axes.get_title(), '')
    self.assertEqual(blank_axes.get_xlabel(), '')
    self.assertEqual(blank_axes.get_ylabel(), '')
    self.assertLessEqual(len(blank_axes.patches), 30)
    expected_rgba = [
        (1., 0, 0, 1.),  # r
        (.75, .75, 0, 1.),  # y
        (0, .5, 0, 1.),  # g
        (0, 0, 1., 1.),  # b
    ]
    for position, rgba in enumerate(expected_rgba):
        self.assertTupleEqual(blank_axes.patches[position].get_facecolor(), rgba)

@unittest.skipIf(not MATPLOTLIB_INSTALLED, 'matplotlib is not installed')
def test_plot_split_value_histogram(self):
    """Check the axes returned by ``lgb.plot_split_value_histogram``.

    The feature is selected by index on a Booster and by name on an
    ``LGBMClassifier``. The last check expects ``ValueError`` for feature 0,
    which was not used in splitting.
    """
    booster = lgb.train(self.params, self.train_data, num_boost_round=10)

    # Feature chosen by index, default labels.
    by_index_axes = lgb.plot_split_value_histogram(booster, 27)
    self.assertIsInstance(by_index_axes, matplotlib.axes.Axes)
    self.assertEqual(by_index_axes.get_title(), 'Split value histogram for feature with index 27')
    self.assertEqual(by_index_axes.get_xlabel(), 'Feature split value')
    self.assertEqual(by_index_axes.get_ylabel(), 'Count')
    self.assertLessEqual(len(by_index_axes.patches), 2)

    classifier = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, silent=True)
    classifier.fit(self.X_train, self.y_train)

    # Feature chosen by name; the title template placeholders get substituted.
    feature_name = classifier.booster_.feature_name()[27]
    by_name_axes = lgb.plot_split_value_histogram(classifier, feature_name, figsize=(10, 5),
                                                  title='Histogram for feature @index/name@ @feature@',
                                                  xlabel='x', ylabel='y', color='r')
    self.assertIsInstance(by_name_axes, matplotlib.axes.Axes)
    expected_title = 'Histogram for feature name {}'.format(classifier.booster_.feature_name()[27])
    self.assertEqual(by_name_axes.get_title(), expected_title)
    self.assertEqual(by_name_axes.get_xlabel(), 'x')
    self.assertEqual(by_name_axes.get_ylabel(), 'y')
    self.assertLessEqual(len(by_name_axes.patches), 2)
    red_rgba = (1., 0, 0, 1.)
    for bar in by_name_axes.patches:
        self.assertTupleEqual(bar.get_facecolor(), red_rgba)

    # Ten bins, blank labels, and a color list applied to the first four bars.
    binned_axes = lgb.plot_split_value_histogram(booster, 27, bins=10, color=['r', 'y', 'g', 'b'],
                                                 title=None, xlabel=None, ylabel=None)
    self.assertIsInstance(binned_axes, matplotlib.axes.Axes)
    self.assertEqual(binned_axes.get_title(), '')
    self.assertEqual(binned_axes.get_xlabel(), '')
    self.assertEqual(binned_axes.get_ylabel(), '')
    self.assertEqual(len(binned_axes.patches), 10)
    expected_rgba = [
        (1., 0, 0, 1.),  # r
        (.75, .75, 0, 1.),  # y
        (0, .5, 0, 1.),  # g
        (0, 0, 1., 1.),  # b
    ]
    for position, rgba in enumerate(expected_rgba):
        self.assertTupleEqual(binned_axes.patches[position].get_facecolor(), rgba)

    # Feature 0 was not used in splitting.
    with self.assertRaises(ValueError):
        lgb.plot_split_value_histogram(booster, 0)

@unittest.skipIf(not (MATPLOTLIB_INSTALLED and GRAPHVIZ_INSTALLED), 'matplotlib or graphviz is not installed')
def test_plot_tree(self):
    """Check ``lgb.plot_tree``: out-of-range index handling and the requested figure size."""
    classifier = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, silent=True)
    classifier.fit(self.X_train, self.y_train, verbose=False)

    # The model is fitted with n_estimators=10; tree_index=83 is expected to fail.
    with self.assertRaises(IndexError):
        lgb.plot_tree(classifier, tree_index=83)

    tree_axes = lgb.plot_tree(classifier, tree_index=3, figsize=(15, 8), show_info=['split_gain'])
    self.assertIsInstance(tree_axes, matplotlib.axes.Axes)

    # The figure behind the axes should carry the requested figsize.
    figure = tree_axes.axes.get_figure()
    width, height = figure.get_size_inches()
    self.assertEqual(int(width), 15)
    self.assertEqual(int(height), 8)

@unittest.skipIf(not GRAPHVIZ_INSTALLED, 'graphviz is not installed')
def test_create_tree_digraph(self):
    """Check the ``graphviz.Digraph`` produced by ``lgb.create_tree_digraph``.

    The classifier is fitted with alternating -1/1 monotone constraints. The
    test then inspects the graph name, filename, attributes, and which
    substrings do or do not appear in the rendered body.
    """
    n_features = self.X_train.shape[1]
    monotone = [-1, 1] * (n_features // 2)
    classifier = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, silent=True, monotone_constraints=monotone)
    classifier.fit(self.X_train, self.y_train, verbose=False)

    # The model is fitted with n_estimators=10; tree_index=83 is expected to fail.
    with self.assertRaises(IndexError):
        lgb.create_tree_digraph(classifier, tree_index=83)

    digraph = lgb.create_tree_digraph(classifier, tree_index=3,
                                      show_info=['split_gain', 'internal_value', 'internal_weight'],
                                      name='Tree4', node_attr={'color': 'red'})
    digraph.render(view=False)
    self.assertIsInstance(digraph, graphviz.Digraph)
    self.assertEqual(digraph.name, 'Tree4')
    self.assertEqual(digraph.filename, 'Tree4.gv')
    self.assertEqual(len(digraph.node_attr), 1)
    self.assertEqual(digraph.node_attr['color'], 'red')
    self.assertEqual(len(digraph.graph_attr), 0)
    self.assertEqual(len(digraph.edge_attr), 0)

    body_text = ''.join(digraph.body)
    # Substrings that must appear in the graph body.
    for expected in ('leaf', 'gain', 'value', 'weight', '#ffdddd', '#ddffdd'):
        self.assertIn(expected, body_text)
    # Substrings that must be absent from the graph body.
    for unexpected in ('data', 'count'):
        self.assertNotIn(unexpected, body_text)

@unittest.skipIf(not MATPLOTLIB_INSTALLED, 'matplotlib is not installed')
def test_plot_metrics(self):
    """Check ``lgb.plot_metric`` on recorded eval results and on a fitted classifier.

    Three scenarios are exercised:
    1. Training with two validation sets fills ``evals_result0``; plotting it
       gives default title/xlabel and a ylabel that is one of the two metrics.
    2. Training without validation sets leaves nothing to plot, so
       ``plot_metric`` is expected to raise ``ValueError``.
    3. Plotting a fitted ``LGBMClassifier`` with all labels set to ``None``
       gives empty title and axis labels.
    """
    test_data = lgb.Dataset(self.X_test, self.y_test, reference=self.train_data)
    self.params.update({"metric": {"binary_logloss", "binary_error"}})

    # The returned Booster is not needed; only the recorded results are used.
    evals_result0 = {}
    lgb.train(self.params, self.train_data,
              valid_sets=[self.train_data, test_data],
              valid_names=['v1', 'v2'],
              num_boost_round=10,
              evals_result=evals_result0,
              verbose_eval=False)
    ax0 = lgb.plot_metric(evals_result0)
    self.assertIsInstance(ax0, matplotlib.axes.Axes)
    self.assertEqual(ax0.get_title(), 'Metric during training')
    self.assertEqual(ax0.get_xlabel(), 'Iterations')
    self.assertIn(ax0.get_ylabel(), {'binary_logloss', 'binary_error'})
    # Smoke calls: an explicit metric, and a subset of dataset names.
    ax0 = lgb.plot_metric(evals_result0, metric='binary_error')
    ax0 = lgb.plot_metric(evals_result0, metric='binary_logloss', dataset_names=['v2'])

    # No validation sets are passed here, and plotting is expected to fail.
    evals_result1 = {}
    lgb.train(self.params, self.train_data,
              num_boost_round=10,
              evals_result=evals_result1,
              verbose_eval=False)
    self.assertRaises(ValueError, lgb.plot_metric, evals_result1)

    gbm2 = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, silent=True)
    gbm2.fit(self.X_train, self.y_train, eval_set=[(self.X_test, self.y_test)], verbose=False)
    ax2 = lgb.plot_metric(gbm2, title=None, xlabel=None, ylabel=None)
    self.assertIsInstance(ax2, matplotlib.axes.Axes)
    self.assertEqual(ax2.get_title(), '')
    self.assertEqual(ax2.get_xlabel(), '')
    self.assertEqual(ax2.get_ylabel(), '')
"num_leaves": 3}


@pytest.mark.skipif(not MATPLOTLIB_INSTALLED, reason='matplotlib is not installed')
def test_plot_importance(params, breast_cancer_split, train_data):
    """Check the axes returned by ``lgb.plot_importance``.

    Covers three calls: defaults on a trained Booster, custom labels and a
    single color on an ``LGBMClassifier``, and blank labels with a list of
    colors on the Booster.
    """
    features, _, target, _ = breast_cancer_split
    booster = lgb.train(params, train_data, num_boost_round=10)

    # Defaults: built-in title and axis labels.
    default_axes = lgb.plot_importance(booster)
    assert isinstance(default_axes, matplotlib.axes.Axes)
    assert default_axes.get_title() == 'Feature importance'
    assert default_axes.get_xlabel() == 'Feature importance'
    assert default_axes.get_ylabel() == 'Features'
    assert len(default_axes.patches) <= 30

    classifier = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, silent=True)
    classifier.fit(features, target)

    # Custom labels and one color: every bar is expected to be red.
    red_rgba = (1., 0, 0, 1.)
    labelled_axes = lgb.plot_importance(classifier, color='r', title='t', xlabel='x', ylabel='y')
    assert isinstance(labelled_axes, matplotlib.axes.Axes)
    assert labelled_axes.get_title() == 't'
    assert labelled_axes.get_xlabel() == 'x'
    assert labelled_axes.get_ylabel() == 'y'
    assert len(labelled_axes.patches) <= 30
    for bar in labelled_axes.patches:
        assert bar.get_facecolor() == red_rgba

    # ``None`` labels come back as empty strings; the first four bars follow
    # the order of the color list.
    blank_axes = lgb.plot_importance(booster, color=['r', 'y', 'g', 'b'],
                                     title=None, xlabel=None, ylabel=None)
    assert isinstance(blank_axes, matplotlib.axes.Axes)
    assert blank_axes.get_title() == ''
    assert blank_axes.get_xlabel() == ''
    assert blank_axes.get_ylabel() == ''
    assert len(blank_axes.patches) <= 30
    expected_rgba = [
        (1., 0, 0, 1.),  # r
        (.75, .75, 0, 1.),  # y
        (0, .5, 0, 1.),  # g
        (0, 0, 1., 1.),  # b
    ]
    for position, rgba in enumerate(expected_rgba):
        assert blank_axes.patches[position].get_facecolor() == rgba


@pytest.mark.skipif(not MATPLOTLIB_INSTALLED, reason='matplotlib is not installed')
def test_plot_split_value_histogram(params, breast_cancer_split, train_data):
    """Check the axes returned by ``lgb.plot_split_value_histogram``.

    The feature is selected by index on a Booster and by name on an
    ``LGBMClassifier``. The last check expects ``ValueError`` for feature 0,
    which was not used in splitting.
    """
    features, _, target, _ = breast_cancer_split
    booster = lgb.train(params, train_data, num_boost_round=10)

    # Feature chosen by index, default labels.
    by_index_axes = lgb.plot_split_value_histogram(booster, 27)
    assert isinstance(by_index_axes, matplotlib.axes.Axes)
    assert by_index_axes.get_title() == 'Split value histogram for feature with index 27'
    assert by_index_axes.get_xlabel() == 'Feature split value'
    assert by_index_axes.get_ylabel() == 'Count'
    assert len(by_index_axes.patches) <= 2

    classifier = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, silent=True)
    classifier.fit(features, target)

    # Feature chosen by name; the title template placeholders get substituted.
    feature_name = classifier.booster_.feature_name()[27]
    by_name_axes = lgb.plot_split_value_histogram(classifier, feature_name, figsize=(10, 5),
                                                  title='Histogram for feature @index/name@ @feature@',
                                                  xlabel='x', ylabel='y', color='r')
    assert isinstance(by_name_axes, matplotlib.axes.Axes)
    expected_title = 'Histogram for feature name {}'.format(classifier.booster_.feature_name()[27])
    assert by_name_axes.get_title() == expected_title
    assert by_name_axes.get_xlabel() == 'x'
    assert by_name_axes.get_ylabel() == 'y'
    assert len(by_name_axes.patches) <= 2
    red_rgba = (1., 0, 0, 1.)
    for bar in by_name_axes.patches:
        assert bar.get_facecolor() == red_rgba

    # Ten bins, blank labels, and a color list applied to the first four bars.
    binned_axes = lgb.plot_split_value_histogram(booster, 27, bins=10, color=['r', 'y', 'g', 'b'],
                                                 title=None, xlabel=None, ylabel=None)
    assert isinstance(binned_axes, matplotlib.axes.Axes)
    assert binned_axes.get_title() == ''
    assert binned_axes.get_xlabel() == ''
    assert binned_axes.get_ylabel() == ''
    assert len(binned_axes.patches) == 10
    expected_rgba = [
        (1., 0, 0, 1.),  # r
        (.75, .75, 0, 1.),  # y
        (0, .5, 0, 1.),  # g
        (0, 0, 1., 1.),  # b
    ]
    for position, rgba in enumerate(expected_rgba):
        assert binned_axes.patches[position].get_facecolor() == rgba

    # Feature 0 was not used in splitting.
    with pytest.raises(ValueError):
        lgb.plot_split_value_histogram(booster, 0)


@pytest.mark.skipif(not (MATPLOTLIB_INSTALLED and GRAPHVIZ_INSTALLED),
                    reason='matplotlib or graphviz is not installed')
def test_plot_tree(breast_cancer_split):
    """Check ``lgb.plot_tree``: out-of-range index handling and the requested figure size."""
    features, _, target, _ = breast_cancer_split
    classifier = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, silent=True)
    classifier.fit(features, target, verbose=False)

    # The model is fitted with n_estimators=10; tree_index=83 is expected to fail.
    with pytest.raises(IndexError):
        lgb.plot_tree(classifier, tree_index=83)

    tree_axes = lgb.plot_tree(classifier, tree_index=3, figsize=(15, 8), show_info=['split_gain'])
    assert isinstance(tree_axes, matplotlib.axes.Axes)

    # The figure behind the axes should carry the requested figsize.
    figure = tree_axes.axes.get_figure()
    width, height = figure.get_size_inches()
    assert int(width) == 15
    assert int(height) == 8


@pytest.mark.skipif(not GRAPHVIZ_INSTALLED, reason='graphviz is not installed')
def test_create_tree_digraph(breast_cancer_split):
    """Check the ``graphviz.Digraph`` produced by ``lgb.create_tree_digraph``.

    The classifier is fitted with alternating -1/1 monotone constraints. The
    test then inspects the graph name, filename, attributes, and which
    substrings do or do not appear in the rendered body.
    """
    features, _, target, _ = breast_cancer_split

    n_features = features.shape[1]
    monotone = [-1, 1] * (n_features // 2)
    classifier = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, silent=True, monotone_constraints=monotone)
    classifier.fit(features, target, verbose=False)

    # The model is fitted with n_estimators=10; tree_index=83 is expected to fail.
    with pytest.raises(IndexError):
        lgb.create_tree_digraph(classifier, tree_index=83)

    digraph = lgb.create_tree_digraph(classifier, tree_index=3,
                                      show_info=['split_gain', 'internal_value', 'internal_weight'],
                                      name='Tree4', node_attr={'color': 'red'})
    digraph.render(view=False)
    assert isinstance(digraph, graphviz.Digraph)
    assert digraph.name == 'Tree4'
    assert digraph.filename == 'Tree4.gv'
    assert len(digraph.node_attr) == 1
    assert digraph.node_attr['color'] == 'red'
    assert len(digraph.graph_attr) == 0
    assert len(digraph.edge_attr) == 0

    body_text = ''.join(digraph.body)
    # Substrings that must appear in the graph body.
    for expected in ('leaf', 'gain', 'value', 'weight', '#ffdddd', '#ddffdd'):
        assert expected in body_text
    # Substrings that must be absent from the graph body.
    for unexpected in ('data', 'count'):
        assert unexpected not in body_text


@pytest.mark.skipif(not MATPLOTLIB_INSTALLED, reason='matplotlib is not installed')
def test_plot_metrics(params, breast_cancer_split, train_data):
    """Check ``lgb.plot_metric`` on recorded eval results and on a fitted classifier.

    Three scenarios are exercised:
    1. Training with two validation sets fills a results dict; plotting it
       gives default title/xlabel and a ylabel that is one of the two metrics.
    2. Training without validation sets leaves nothing to plot, so
       ``plot_metric`` is expected to raise ``ValueError``.
    3. Plotting a fitted ``LGBMClassifier`` with all labels set to ``None``
       gives empty title and axis labels.
    """
    features, held_out_features, target, held_out_target = breast_cancer_split
    validation_data = lgb.Dataset(held_out_features, held_out_target, reference=train_data)
    params["metric"] = {"binary_logloss", "binary_error"}

    # Scenario 1: results recorded for two named validation sets.
    recorded = {}
    lgb.train(
        params,
        train_data,
        valid_sets=[train_data, validation_data],
        valid_names=['v1', 'v2'],
        num_boost_round=10,
        evals_result=recorded,
        verbose_eval=False,
    )
    metric_axes = lgb.plot_metric(recorded)
    assert isinstance(metric_axes, matplotlib.axes.Axes)
    assert metric_axes.get_title() == 'Metric during training'
    assert metric_axes.get_xlabel() == 'Iterations'
    assert metric_axes.get_ylabel() in {'binary_logloss', 'binary_error'}
    # Smoke calls: an explicit metric, and a subset of dataset names.
    metric_axes = lgb.plot_metric(recorded, metric='binary_error')
    metric_axes = lgb.plot_metric(recorded, metric='binary_logloss', dataset_names=['v2'])

    # Scenario 2: no validation sets are passed, and plotting is expected to fail.
    recorded_without_valid = {}
    lgb.train(
        params,
        train_data,
        num_boost_round=10,
        evals_result=recorded_without_valid,
        verbose_eval=False,
    )
    with pytest.raises(ValueError):
        lgb.plot_metric(recorded_without_valid)

    # Scenario 3: plot directly from a fitted sklearn-style estimator.
    classifier = lgb.LGBMClassifier(n_estimators=10, num_leaves=3, silent=True)
    classifier.fit(features, target, eval_set=[(held_out_features, held_out_target)], verbose=False)
    blank_axes = lgb.plot_metric(classifier, title=None, xlabel=None, ylabel=None)
    assert isinstance(blank_axes, matplotlib.axes.Axes)
    assert blank_axes.get_title() == ''
    assert blank_axes.get_xlabel() == ''
    assert blank_axes.get_ylabel() == ''