Add scheduler parameter (dask#44)
* Add `scheduler` parameter

Allow specifying the scheduler by name instead of passing in the `get`
function directly

* A few tweaks

- Copy of fix for pickling masked arrays
- Add distributed to travis

* Support n_jobs

Still needs tests.

* Support n_jobs parameter

- Add tests for n_jobs
- Add tests for scheduler parameter

* Support scheduler aliases
jcrist authored Apr 18, 2017
1 parent 1213422 commit 36ef51b
Showing 5 changed files with 168 additions and 33 deletions.
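A quick usage sketch of the interface this commit adds (the estimator, data, and parameter grid below are illustrative only, not part of the change):

    import dask_searchcv as dcv
    from sklearn.datasets import make_classification
    from sklearn.svm import SVC

    X, y = make_classification(n_samples=100, n_features=10, random_state=0)

    # Name the scheduler instead of passing a ``get`` function; n_jobs caps the workers
    search = dcv.GridSearchCV(SVC(), {'C': [0.1, 1, 10]},
                              scheduler='threading', n_jobs=4)
    search.fit(X, y)

    # The synchronous scheduler (alias 'sync') is handy for debugging
    search = dcv.GridSearchCV(SVC(), {'C': [0.1, 1, 10]}, scheduler='sync')
    search.fit(X, y)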
2 changes: 1 addition & 1 deletion .travis.yml
@@ -23,7 +23,7 @@ install:
  # Install dependencies
  - conda create -n test-environment python=$PYTHON
  - source activate test-environment
  - conda install dask numpy scikit-learn=$SKLEARN cytoolz pytest
  - conda install dask distributed numpy scikit-learn=$SKLEARN cytoolz pytest
  - pip install -q graphviz flake8
  - pip install --no-deps -e .

23 changes: 19 additions & 4 deletions dask_searchcv/methods.py
@@ -3,6 +3,7 @@
import warnings
from collections import defaultdict
from threading import Lock
from distutils.version import LooseVersion

import numpy as np
from toolz import pluck
@@ -17,10 +18,24 @@

from .utils import copy_estimator

try:
    from sklearn.utils.fixes import MaskedArray
except: # pragma: no cover
    from numpy.ma import MaskedArray
# Copied from scikit-learn/sklearn/utils/fixes.py, can be removed once we drop
# support for scikit-learn < 0.18.1 or numpy < 1.12.0.
if LooseVersion(np.__version__) < '1.12.0':
    class MaskedArray(np.ma.MaskedArray):
        # Before numpy 1.12, np.ma.MaskedArray object is not picklable
        # This fix is needed to make our model_selection.GridSearchCV
        # picklable as the ``cv_results_`` param uses MaskedArray
        def __getstate__(self):
            """Return the internal state of the masked array, for pickling
            purposes.
            """
            cf = 'CF'[self.flags.fnc]
            data_state = super(np.ma.MaskedArray, self).__reduce__()[2]
            return data_state + (np.ma.getmaskarray(self).tostring(cf),
                                 self._fill_value)
else:
    from numpy.ma import MaskedArray  # noqa

# A singleton to indicate a missing parameter
MISSING = type('MissingParameter', (object,),
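The shim above keeps ``cv_results_`` (which stores parameters in a ``MaskedArray``) picklable on older numpy. A minimal sketch of the round trip it is meant to support, assuming the package is importable as built from this commit:

    import pickle
    import numpy as np
    from dask_searchcv.methods import MaskedArray  # the class defined above

    arr = MaskedArray(data=np.arange(3), mask=[False, True, False])
    restored = pickle.loads(pickle.dumps(arr))
    assert (np.ma.getmaskarray(restored) == np.ma.getmaskarray(arr)).all()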
102 changes: 85 additions & 17 deletions dask_searchcv/model_selection.py
@@ -3,6 +3,7 @@
from operator import getitem
from collections import defaultdict
from itertools import repeat
from multiprocessing import cpu_count
import numbers

import numpy as np
@@ -615,21 +616,76 @@ def compute_n_splits(cv, X, y=None, groups=None):
    return delayed(cv).get_n_splits(X, y, groups).compute()


def _normalize_n_jobs(n_jobs):
    if not isinstance(n_jobs, int):
        raise TypeError("n_jobs should be an int, got %s" % n_jobs)
    if n_jobs == -1:
        n_jobs = None  # Scheduler default is use all cores
    elif n_jobs < -1:
        n_jobs = cpu_count() + 1 + n_jobs
    return n_jobs


_scheduler_aliases = {'sync': 'synchronous',
                      'sequential': 'synchronous',
                      'threaded': 'threading'}


def _normalize_scheduler(scheduler, n_jobs, loop=None):
    # Default
    if scheduler is None:
        scheduler = dask.context._globals.get('get')
        if scheduler is None:
            scheduler = dask.get if n_jobs == 1 else threaded_get
        return scheduler

    # Get-functions
    if callable(scheduler):
        return scheduler

    # Support name aliases
    if isinstance(scheduler, str):
        scheduler = _scheduler_aliases.get(scheduler, scheduler)

    if scheduler in ('threading', 'multiprocessing') and n_jobs == 1:
        scheduler = dask.get
    elif scheduler == 'threading':
        scheduler = threaded_get
    elif scheduler == 'multiprocessing':
        from dask.multiprocessing import get as scheduler
    elif scheduler == 'synchronous':
        scheduler = dask.get
    else:
        try:
            from dask.distributed import Client
            # We pass loop to make testing possible, not needed for normal use
            return Client(scheduler, set_as_default=False, loop=loop).get
        except Exception as e:
            msg = ("Failed to initialize scheduler from parameter %r. "
                   "This could be due to a typo, or a failure to initialize "
                   "the distributed scheduler. Original error is below:\n\n"
                   "%r" % (scheduler, e))
        # Re-raise outside the except to provide a cleaner error message
        raise ValueError(msg)
    return scheduler


class DaskBaseSearchCV(BaseEstimator, MetaEstimatorMixin):
    """Base class for hyper parameter search with cross-validation."""

    def __init__(self, estimator, scoring=None, iid=True, refit=True, cv=None,
                 error_score='raise', return_train_score=True, cache_cv=True,
                 get=None):
                 error_score='raise', return_train_score=True, scheduler=None,
                 n_jobs=-1, cache_cv=True):
        self.scoring = scoring
        self.estimator = estimator
        self.iid = iid
        self.refit = refit
        self.cv = cv
        self.error_score = error_score
        self.return_train_score = return_train_score
        self.scheduler = scheduler
        self.n_jobs = n_jobs
        self.cache_cv = cache_cv
        self.get = get

    @property
    def _estimator_type(self):
@@ -739,8 +795,10 @@ def fit(self, X, y=None, groups=None, **fit_params):
        self.dask_graph_ = dsk
        self.n_splits_ = n_splits

        get = self.get or dask.context._globals.get('get') or threaded_get
        out = get(dsk, keys)
        n_jobs = _normalize_n_jobs(self.n_jobs)
        scheduler = _normalize_scheduler(self.scheduler, n_jobs)

        out = scheduler(dsk, keys, num_workers=n_jobs)

        self.cv_results_ = results = out[0]
        self.best_index_ = np.flatnonzero(results["rank_test_score"] == 1)[0]
@@ -831,10 +889,18 @@ def visualize(self, filename='mydask', format=None, **kwargs):
    If ``'False'``, the ``cv_results_`` attribute will not include training
    scores.
get : None (default) or scheduler get function
    The dask scheduler ``get`` function to use. Default is to use the
    global scheduler if set, and fallback to the threaded scheduler
    otherwise.
scheduler : string, callable, or None, default=None
    The dask scheduler to use. Default is to use the global scheduler if set,
    and fall back to the threaded scheduler otherwise. To use a different
    scheduler, specify it by name (either "threading", "multiprocessing",
    or "synchronous") or provide the scheduler ``get`` function. Other
    arguments are assumed to be the address of a distributed scheduler,
    and are passed to ``dask.distributed.Client``.
n_jobs : int, default=-1
    Number of jobs to run in parallel. Ignored for the synchronous and
    distributed schedulers. If ``n_jobs == -1`` [default] all CPUs are used.
    For ``n_jobs < -1``, ``(n_cpus + 1 + n_jobs)`` are used.
cache_cv : bool, default=True
    Whether to extract each train/test subset at most once in each worker
@@ -959,8 +1025,8 @@ def visualize(self, filename='mydask', format=None, **kwargs):
                  kernel=..., max_iter=-1, probability=False,
                  random_state=..., shrinking=..., tol=...,
                  verbose=...),
        get=..., iid=..., param_grid=..., refit=...,
        return_train_score=..., scoring=...)
        iid=..., n_jobs=..., param_grid=..., refit=..., return_train_score=...,
        scheduler=..., scoring=...)
>>> sorted(clf.cv_results_.keys()) # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
['mean_test_score', 'mean_train_score', 'param_C', 'param_kernel',...
 'params', 'rank_test_score', 'split0_test_score', 'split0_train_score',...
@@ -978,11 +1044,12 @@ class GridSearchCV(DaskBaseSearchCV):

    def __init__(self, estimator, param_grid, scoring=None, iid=True,
                 refit=True, cv=None, error_score='raise',
                 return_train_score=True, get=None, cache_cv=True):
                 return_train_score=True, scheduler=None, n_jobs=-1,
                 cache_cv=True):
        super(GridSearchCV, self).__init__(estimator=estimator,
                scoring=scoring, iid=iid, refit=refit, cv=cv,
                error_score=error_score, return_train_score=return_train_score,
                get=get, cache_cv=cache_cv)
                scheduler=scheduler, n_jobs=n_jobs, cache_cv=cache_cv)

        _check_param_grid(param_grid)
        self.param_grid = param_grid
@@ -1038,8 +1105,9 @@ def _get_param_iterator(self):
                  kernel=..., max_iter=..., probability=...,
                  random_state=..., shrinking=..., tol=...,
                  verbose=...),
        get=..., iid=..., n_iter=..., param_distributions=...,
        random_state=..., refit=..., return_train_score=..., scoring=...)
        iid=..., n_iter=..., n_jobs=..., param_distributions=...,
        random_state=..., refit=..., return_train_score=...,
        scheduler=..., scoring=...)
>>> sorted(clf.cv_results_.keys()) # doctest: +NORMALIZE_WHITESPACE +ELLIPSIS
['mean_test_score', 'mean_train_score', 'param_C', 'param_kernel',...
 'params', 'rank_test_score', 'split0_test_score', 'split0_train_score',...
@@ -1058,12 +1126,12 @@ class RandomizedSearchCV(DaskBaseSearchCV):
    def __init__(self, estimator, param_distributions, n_iter=10,
                 random_state=None, scoring=None, iid=True, refit=True,
                 cv=None, error_score='raise', return_train_score=True,
                 get=None, cache_cv=True):
                 scheduler=None, n_jobs=-1, cache_cv=True):

        super(RandomizedSearchCV, self).__init__(estimator=estimator,
                scoring=scoring, iid=iid, refit=refit, cv=cv,
                error_score=error_score, return_train_score=return_train_score,
                get=get, cache_cv=cache_cv)
                scheduler=scheduler, n_jobs=n_jobs, cache_cv=cache_cv)

        self.param_distributions = param_distributions
        self.n_iter = n_iter
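To make the resolution rules above concrete, a small sketch against the new helpers (private API, so subject to change; it also assumes no global dask scheduler has been set):

    import dask
    from multiprocessing import cpu_count
    from dask.threaded import get as threaded_get
    from dask_searchcv.model_selection import (_normalize_n_jobs,
                                               _normalize_scheduler)

    assert _normalize_scheduler('threaded', n_jobs=4) is threaded_get  # alias of 'threading'
    assert _normalize_scheduler('sync', n_jobs=4) is dask.get          # alias of 'synchronous'
    assert _normalize_scheduler('threading', n_jobs=1) is dask.get     # one job runs synchronously
    assert _normalize_scheduler(None, n_jobs=1) is dask.get            # default, single job, no global get
    assert _normalize_n_jobs(-1) is None             # scheduler default: use all cores
    assert _normalize_n_jobs(-2) == cpu_count() - 1  # n_cpus + 1 + n_jobs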
63 changes: 59 additions & 4 deletions dask_searchcv/tests/test_model_selection.py
@@ -3,6 +3,7 @@
import os
import pickle
from itertools import product
from multiprocessing import cpu_count

import pytest
import numpy as np
@@ -13,6 +14,7 @@
from dask.base import tokenize
from dask.callbacks import Callback
from dask.delayed import delayed
from dask.threaded import get as get_threading
from dask.utils import tmpdir

from sklearn.datasets import make_classification, load_iris
@@ -37,12 +39,21 @@
from sklearn.svm import SVC

import dask_searchcv as dcv
from dask_searchcv.model_selection import compute_n_splits, check_cv
from dask_searchcv.model_selection import (compute_n_splits, check_cv,
                                            _normalize_n_jobs, _normalize_scheduler)
from dask_searchcv.methods import CVCache
from dask_searchcv.utils_test import (FailingClassifier, MockClassifier,
                                      ScalingTransformer, CheckXClassifier,
                                      ignore_warnings)

try:
    from distributed import Client
    from distributed.utils_test import cluster, loop
    has_distributed = True
except:
    loop = pytest.fixture(lambda: None)
    has_distributed = False


class assert_dask_compute(Callback):
    def __init__(self, compute=False):
@@ -315,7 +326,7 @@ def test_pipeline_feature_union():

    gs = GridSearchCV(pipe, param_grid=param_grid)
    gs.fit(X, y)
    dgs = dcv.GridSearchCV(pipe, param_grid=param_grid, get=dask.get)
    dgs = dcv.GridSearchCV(pipe, param_grid=param_grid, scheduler='sync')
    dgs.fit(X, y)

    # Check best params match
@@ -359,7 +370,7 @@ def test_pipeline_sub_estimators():

    gs = GridSearchCV(pipe, param_grid=param_grid)
    gs.fit(X, y)
    dgs = dcv.GridSearchCV(pipe, param_grid=param_grid, get=dask.get)
    dgs = dcv.GridSearchCV(pipe, param_grid=param_grid, scheduler='sync')
    dgs.fit(X, y)

    # Check best params match
@@ -529,7 +540,7 @@ def test_cache_cv():
    X, y = make_classification(n_samples=100, n_features=10, random_state=0)
    X2 = X.view(CountTakes)
    gs = dcv.GridSearchCV(MockClassifier(), {'foo_param': [0, 1, 2]},
                          cv=3, cache_cv=False, get=dask.get)
                          cv=3, cache_cv=False, scheduler='sync')
    gs.fit(X2, y)
    assert X2.count == 2 * 3 * 3  # (1 train + 1 test) * n_params * n_splits

@@ -557,3 +568,47 @@ def test_CVCache_serializable():
    assert cache2.pairwise == cache.pairwise
    assert all((cache2.splits[i][j] == cache.splits[i][j]).all()
               for i in range(2) for j in range(2))


def test_normalize_n_jobs():
    assert _normalize_n_jobs(-1) is None
    assert _normalize_n_jobs(-2) == cpu_count() - 1
    with pytest.raises(TypeError):
        _normalize_n_jobs('not an integer')


@pytest.mark.parametrize('scheduler,n_jobs,get',
                         [(None, 4, get_threading),
                          ('threading', 4, get_threading),
                          ('threaded', 4, get_threading),
                          ('threading', 1, dask.get),
                          ('sequential', 4, dask.get),
                          ('synchronous', 4, dask.get),
                          ('sync', 4, dask.get),
                          ('multiprocessing', 4, None),
                          (dask.get, 4, dask.get)])
def test_scheduler_param(scheduler, n_jobs, get):
    if scheduler == 'multiprocessing':
        mp = pytest.importorskip('dask.multiprocessing')
        get = mp.get

    assert _normalize_scheduler(scheduler, n_jobs) is get

    X, y = make_classification(n_samples=100, n_features=10, random_state=0)
    gs = dcv.GridSearchCV(MockClassifier(), {'foo_param': [0, 1, 2]}, cv=3,
                          scheduler=scheduler, n_jobs=n_jobs)
    gs.fit(X, y)


@pytest.mark.skipif('not has_distributed')
def test_scheduler_param_distributed(loop):
    X, y = make_classification(n_samples=100, n_features=10, random_state=0)
    gs = dcv.GridSearchCV(MockClassifier(), {'foo_param': [0, 1, 2]}, cv=3)
    with cluster() as (s, [a, b]):
        with Client(s['address'], loop=loop):
            gs.fit(X, y)


def test_scheduler_param_bad(loop):
    with pytest.raises(ValueError):
        _normalize_scheduler('threeding', 4, loop)
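Any other string is treated as the address of a distributed scheduler and handed to ``dask.distributed.Client``. A hedged sketch (the address is hypothetical and assumes ``distributed`` is installed and a ``dask-scheduler`` is already running there):

    import dask_searchcv as dcv
    from sklearn.datasets import make_classification
    from dask_searchcv.utils_test import MockClassifier

    X, y = make_classification(n_samples=100, n_features=10, random_state=0)
    gs = dcv.GridSearchCV(MockClassifier(), {'foo_param': [0, 1, 2]}, cv=3,
                          scheduler='127.0.0.1:8786')  # address of a running dask-scheduler
    gs.fit(X, y)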
11 changes: 4 additions & 7 deletions dask_searchcv/tests/test_model_selection_sklearn.py
@@ -4,7 +4,6 @@

import pickle
import pytest
from distutils.version import LooseVersion

import dask
import dask.array as da
@@ -14,7 +13,6 @@
import scipy.sparse as sp
from scipy.stats import expon

import sklearn
from sklearn.base import BaseEstimator
from sklearn.cluster import KMeans
from sklearn.datasets import (make_classification, make_blobs,
@@ -112,11 +110,12 @@ def test_grid_search_no_score():
    # wrong results. This only happens with threads, not processes/sync.
    # For now, we'll fit using the sync scheduler.
    grid_search = dcv.GridSearchCV(clf, {'C': Cs}, scoring='accuracy',
                                   get=dask.get)
                                   scheduler='sync')
    grid_search.fit(X, y)

    grid_search_no_score = dcv.GridSearchCV(clf_no_score, {'C': Cs},
                                            scoring='accuracy', get=dask.get)
                                            scoring='accuracy',
                                            scheduler='sync')
    # smoketest grid search
    grid_search_no_score.fit(X, y)

@@ -782,7 +781,7 @@ def test_grid_search_correct_score_results():
    # in wrong results. This only happens with threads, not processes/sync.
    # For now, we'll fit using the sync scheduler.
    grid_search = dcv.GridSearchCV(clf, {'C': Cs}, scoring=score,
                                   cv=n_splits, get=dask.get)
                                   cv=n_splits, scheduler='sync')
    cv_results = grid_search.fit(X, y).cv_results_

    # Test scorer names
@@ -810,8 +809,6 @@ def test_grid_search_correct_score_results():
        assert_almost_equal(correct_score, cv_scores[i])


@pytest.mark.skipif(LooseVersion(sklearn.__version__) < '0.18.1',
                    reason="Pickle of masked-arrays broken in 0.18.0")
def test_pickle():
    # Test that a fit search can be pickled
    clf = MockClassifier()
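With the MaskedArray shim now living in ``methods.py``, the version guard removed above should no longer be needed: a fitted search is expected to round-trip through pickle on any supported scikit-learn. A minimal sketch in the spirit of ``test_pickle`` (estimator and grid borrowed from the tests):

    import pickle
    import dask_searchcv as dcv
    from sklearn.datasets import make_classification
    from dask_searchcv.utils_test import MockClassifier

    X, y = make_classification(n_samples=100, n_features=10, random_state=0)
    gs = dcv.GridSearchCV(MockClassifier(), {'foo_param': [0, 1, 2]}, cv=3,
                          scheduler='sync')
    gs.fit(X, y)
    restored = pickle.loads(pickle.dumps(gs))
    assert restored.best_params_ == gs.best_params_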
