Skip to content

Commit

Permalink
Specify HLG layer name during delayed object creation in fit (#898)
Browse files Browse the repository at this point in the history
  • Loading branch information
ayushdg authored Jan 5, 2022
1 parent 8cb4c2f commit 5466bec
Show file tree
Hide file tree
Showing 6 changed files with 19 additions and 10 deletions.
1 change: 1 addition & 0 deletions dask_ml/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
DASK_2_26_0 = DASK_VERSION >= packaging.version.parse("2.26.0")
DASK_2_28_0 = DASK_VERSION > packaging.version.parse("2.27.0")
DASK_2021_02_0 = DASK_VERSION >= packaging.version.parse("2021.02.0")
+ DASK_2022_01_0 = DASK_VERSION > packaging.version.parse("2021.12.0")
DISTRIBUTED_2_5_0 = DISTRIBUTED_VERSION > packaging.version.parse("2.5.0")
DISTRIBUTED_2_11_0 = DISTRIBUTED_VERSION > packaging.version.parse("2.10.0") # dev
DISTRIBUTED_2021_02_0 = DISTRIBUTED_VERSION >= packaging.version.parse("2021.02.0")
Expand Down
8 changes: 7 additions & 1 deletion dask_ml/_partial.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from dask.highlevelgraph import HighLevelGraph
from toolz import partial

+ from ._compat import DASK_2022_01_0
+
logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -125,7 +127,11 @@ def fit(
if y is not None:
dependencies.append(y)
new_dsk = HighLevelGraph.from_collections(name, dsk, dependencies=dependencies)
- value = Delayed((name, nblocks - 1), new_dsk)
+
+ if DASK_2022_01_0:
+     value = Delayed((name, nblocks - 1), new_dsk, layer=name)
+ else:
+     value = Delayed((name, nblocks - 1), new_dsk)

if compute:
return value.compute()
Expand Down
4 changes: 2 additions & 2 deletions dask_ml/model_selection/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

import warnings
from collections import defaultdict
- from distutils.version import LooseVersion
from threading import Lock
from timeit import default_timer

import numpy as np
+ import packaging.version
from dask.base import normalize_token
from scipy import sparse
from scipy.stats import rankdata
Expand All @@ -20,7 +20,7 @@

# Copied from scikit-learn/sklearn/utils/fixes.py, can be removed once we drop
# support for scikit-learn < 0.18.1 or numpy < 1.12.0.
- if LooseVersion(np.__version__) < "1.12.0":
+ if packaging.version.parse(np.__version__) < packaging.version.parse("1.12.0"):

class MaskedArray(np.ma.MaskedArray):
# Before numpy 1.12, np.ma.MaskedArray object is not picklable
Expand Down
4 changes: 2 additions & 2 deletions dask_ml/model_selection/utils.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
import copy
import warnings
- from distutils.version import LooseVersion
from itertools import compress

import dask
import dask.array as da
import dask.dataframe as dd
import numpy as np
+ import packaging.version
import scipy.sparse as sp
from dask.base import tokenize
from dask.delayed import Delayed, delayed
from sklearn.utils.validation import _is_arraylike, indexable

from ..utils import _num_samples

- if LooseVersion(dask.__version__) > "0.15.4":
+ if packaging.version.parse(dask.__version__) > packaging.version.parse("0.15.4"):
from dask.base import is_dask_collection
else:
from dask.base import Base
Expand Down
6 changes: 3 additions & 3 deletions dask_ml/preprocessing/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
import multiprocessing
import numbers
from collections import OrderedDict
- from distutils.version import LooseVersion
from typing import Any, List, Optional, Sequence, Union

import dask.array as da
import dask.dataframe as dd
import numpy as np
+ import packaging.version
import pandas as pd
import sklearn.preprocessing
from dask import compute
Expand All @@ -26,8 +26,8 @@
from .._typing import ArrayLike, DataFrameType, NDArrayOrScalar, SeriesType
from ..base import DaskMLBaseMixin

- _PANDAS_VERSION = LooseVersion(pd.__version__)
- _HAS_CTD = _PANDAS_VERSION >= "0.21.0"
+ _PANDAS_VERSION = packaging.version.parse(pd.__version__)
+ _HAS_CTD = _PANDAS_VERSION >= packaging.version.parse("0.21.0")
BOUNDS_THRESHOLD = 1e-7


Expand Down
6 changes: 4 additions & 2 deletions tests/linear_model/test_glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,11 @@ def test_fit(fit_intercept, solver):
)
def test_fit_solver(solver):
import dask_glm
- from distutils.version import LooseVersion
+ import packaging.version

- if LooseVersion(dask_glm.__version__) <= "0.2.0":
+ if packaging.version.parse(dask_glm.__version__) <= packaging.version.parse(
+     "0.2.0"
+ ):
pytest.skip("FutureWarning for dask config.")

X, y = make_classification(n_samples=100, n_features=5, chunks=50)
Expand Down

0 comments on commit 5466bec

Please sign in to comment.