Skip to content

Commit

Permalink
[tests][python] reduce unnecessary data loading in tests (#3486)
Browse files Browse the repository at this point in the history
* [ci] [python] reduce unnecessary data loading in tests

* add profiling files to gitignore

* just use cache()

* default on cache size

* patch lru_cache on Python 2.7

* linting

* reduce duplicated code

* missing warnings

* fix imports

* fix lru_cache backport

* missing kwargs

* Apply suggestions from code review

Co-authored-by: Nikita Titov <[email protected]>

* reduce duplicated code

* cache in test_plotting

Co-authored-by: Nikita Titov <[email protected]>
  • Loading branch information
jameslamb and StrikerRUS authored Oct 29, 2020
1 parent 5cc9e67 commit 03c4d45
Show file tree
Hide file tree
Showing 7 changed files with 59 additions and 7 deletions.
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -318,6 +318,8 @@ htmlcov/
.coverage.*
.cache
nosetests.xml
prof/
*.prof
coverage.xml
*,cover
.hypothesis/
Expand Down
Empty file.
4 changes: 3 additions & 1 deletion tests/python_package_test/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
import numpy as np

from scipy import sparse
from sklearn.datasets import load_breast_cancer, dump_svmlight_file, load_svmlight_file
from sklearn.datasets import dump_svmlight_file, load_svmlight_file
from sklearn.model_selection import train_test_split

from .utils import load_breast_cancer


class TestBasic(unittest.TestCase):

Expand Down
6 changes: 4 additions & 2 deletions tests/python_package_test/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@
import lightgbm as lgb
import numpy as np
from scipy.sparse import csr_matrix, isspmatrix_csr, isspmatrix_csc
from sklearn.datasets import (load_boston, load_breast_cancer, load_digits,
load_iris, load_svmlight_file, make_multilabel_classification)
from sklearn.datasets import load_svmlight_file, make_multilabel_classification
from sklearn.metrics import log_loss, mean_absolute_error, mean_squared_error, roc_auc_score, average_precision_score
from sklearn.model_selection import train_test_split, TimeSeriesSplit, GroupKFold

Expand All @@ -20,6 +19,8 @@
except ImportError:
import pickle

from .utils import load_boston, load_breast_cancer, load_digits, load_iris


decreasing_generator = itertools.count(0, -1)

Expand Down Expand Up @@ -2524,6 +2525,7 @@ def test_average_precision_metric(self):
sklearn_ap = average_precision_score(y, pred)
self.assertAlmostEqual(ap, sklearn_ap)
# test that average precision is 1 where model predicts perfectly
y = y.copy()
y[:] = 1
lgb_X = lgb.Dataset(X, label=y)
lgb.train(params, lgb_X, num_boost_round=1, valid_sets=[lgb_X], evals_result=res)
Expand Down
3 changes: 2 additions & 1 deletion tests/python_package_test/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@

import lightgbm as lgb
from lightgbm.compat import MATPLOTLIB_INSTALLED, GRAPHVIZ_INSTALLED
from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split

if MATPLOTLIB_INSTALLED:
Expand All @@ -12,6 +11,8 @@
if GRAPHVIZ_INSTALLED:
import graphviz

from .utils import load_breast_cancer


class TestBasic(unittest.TestCase):

Expand Down
6 changes: 3 additions & 3 deletions tests/python_package_test/test_sklearn.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,7 @@
import numpy as np
from sklearn import __version__ as sk_version
from sklearn.base import clone
from sklearn.datasets import (load_boston, load_breast_cancer, load_digits,
load_iris, load_linnerud, load_svmlight_file,
make_multilabel_classification)
from sklearn.datasets import load_svmlight_file, make_multilabel_classification
from sklearn.exceptions import SkipTestWarning
from sklearn.metrics import log_loss, mean_squared_error
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV, train_test_split
Expand All @@ -22,6 +20,8 @@
check_parameters_default_constructible)
from sklearn.utils.validation import check_is_fitted

from .utils import load_boston, load_breast_cancer, load_digits, load_iris, load_linnerud


decreasing_generator = itertools.count(0, -1)

Expand Down
45 changes: 45 additions & 0 deletions tests/python_package_test/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# coding: utf-8
import sklearn.datasets

try:
    from functools import lru_cache
except ImportError:
    # functools.lru_cache is unavailable (e.g. Python 2.7): fall back to a
    # minimal, unbounded memoizing decorator with the same call shape.
    import warnings
    from functools import wraps

    warnings.warn("Could not import functools.lru_cache", RuntimeWarning)

    def lru_cache(maxsize=None):
        """Return a memoizing decorator used when functools.lru_cache is missing.

        Parameters
        ----------
        maxsize : int or None, optional (default=None)
            Accepted only for signature compatibility with
            ``functools.lru_cache``; this fallback never evicts entries.

        Returns
        -------
        decorator : callable
            Decorator that caches results of the wrapped function, keyed by
            its positional and keyword arguments (which must be hashable).
        """
        def _lru_wrapper(user_function):
            # One cache per decorated function, so that reusing a single
            # decorator instance on several functions cannot mix their results.
            cache = {}

            @wraps(user_function)
            def wrapper(*args, **kwargs):
                # Sort keyword arguments so the key does not depend on
                # dict iteration order.
                arg_key = (args, tuple(sorted(kwargs.items())))
                if arg_key not in cache:
                    cache[arg_key] = user_function(*args, **kwargs)
                return cache[arg_key]
            return wrapper
        return _lru_wrapper


@lru_cache(maxsize=None)
def load_boston(**kwargs):
    """Cached wrapper around ``sklearn.datasets.load_boston``.

    The result is memoized per set of keyword arguments, so repeated calls
    with the same arguments return the same object instead of reloading it.
    Callers that modify the returned data in place should copy it first.
    """
    loader = sklearn.datasets.load_boston
    return loader(**kwargs)


@lru_cache(maxsize=None)
def load_breast_cancer(**kwargs):
    """Cached wrapper around ``sklearn.datasets.load_breast_cancer``.

    The result is memoized per set of keyword arguments, so repeated calls
    with the same arguments return the same object instead of reloading it.
    Callers that modify the returned data in place should copy it first.
    """
    loader = sklearn.datasets.load_breast_cancer
    return loader(**kwargs)


@lru_cache(maxsize=None)
def load_digits(**kwargs):
    """Cached wrapper around ``sklearn.datasets.load_digits``.

    The result is memoized per set of keyword arguments, so repeated calls
    with the same arguments return the same object instead of reloading it.
    Callers that modify the returned data in place should copy it first.
    """
    loader = sklearn.datasets.load_digits
    return loader(**kwargs)


@lru_cache(maxsize=None)
def load_iris(**kwargs):
    """Cached wrapper around ``sklearn.datasets.load_iris``.

    The result is memoized per set of keyword arguments, so repeated calls
    with the same arguments return the same object instead of reloading it.
    Callers that modify the returned data in place should copy it first.
    """
    loader = sklearn.datasets.load_iris
    return loader(**kwargs)


@lru_cache(maxsize=None)
def load_linnerud(**kwargs):
    """Cached wrapper around ``sklearn.datasets.load_linnerud``.

    The result is memoized per set of keyword arguments, so repeated calls
    with the same arguments return the same object instead of reloading it.
    Callers that modify the returned data in place should copy it first.
    """
    loader = sklearn.datasets.load_linnerud
    return loader(**kwargs)

0 comments on commit 03c4d45

Please sign in to comment.