From 5466bec75bd602cd537cb0dd423dd88520fbf466 Mon Sep 17 00:00:00 2001 From: Ayush Dattagupta Date: Wed, 5 Jan 2022 11:27:23 -0800 Subject: [PATCH] Specify HLG layer name during delayed object creation in fit (#898) --- dask_ml/_compat.py | 1 + dask_ml/_partial.py | 8 +++++++- dask_ml/model_selection/methods.py | 4 ++-- dask_ml/model_selection/utils.py | 4 ++-- dask_ml/preprocessing/data.py | 6 +++--- tests/linear_model/test_glm.py | 6 ++++-- 6 files changed, 19 insertions(+), 10 deletions(-) diff --git a/dask_ml/_compat.py b/dask_ml/_compat.py index 8e65d9d91..f5bd18140 100644 --- a/dask_ml/_compat.py +++ b/dask_ml/_compat.py @@ -22,6 +22,7 @@ DASK_2_26_0 = DASK_VERSION >= packaging.version.parse("2.26.0") DASK_2_28_0 = DASK_VERSION > packaging.version.parse("2.27.0") DASK_2021_02_0 = DASK_VERSION >= packaging.version.parse("2021.02.0") +DASK_2022_01_0 = DASK_VERSION > packaging.version.parse("2021.12.0") DISTRIBUTED_2_5_0 = DISTRIBUTED_VERSION > packaging.version.parse("2.5.0") DISTRIBUTED_2_11_0 = DISTRIBUTED_VERSION > packaging.version.parse("2.10.0") # dev DISTRIBUTED_2021_02_0 = DISTRIBUTED_VERSION >= packaging.version.parse("2021.02.0") diff --git a/dask_ml/_partial.py b/dask_ml/_partial.py index 0495edf9e..943237336 100644 --- a/dask_ml/_partial.py +++ b/dask_ml/_partial.py @@ -10,6 +10,8 @@ from dask.highlevelgraph import HighLevelGraph from toolz import partial +from ._compat import DASK_2022_01_0 + logger = logging.getLogger(__name__) @@ -125,7 +127,11 @@ def fit( if y is not None: dependencies.append(y) new_dsk = HighLevelGraph.from_collections(name, dsk, dependencies=dependencies) - value = Delayed((name, nblocks - 1), new_dsk) + + if DASK_2022_01_0: + value = Delayed((name, nblocks - 1), new_dsk, layer=name) + else: + value = Delayed((name, nblocks - 1), new_dsk) if compute: return value.compute() diff --git a/dask_ml/model_selection/methods.py b/dask_ml/model_selection/methods.py index 16dd6c501..86f5e4785 100644 --- a/dask_ml/model_selection/methods.py +++ b/dask_ml/model_selection/methods.py @@ -2,11 +2,11 @@ import warnings from collections import defaultdict -from distutils.version import LooseVersion from threading import Lock from timeit import default_timer import numpy as np +import packaging.version from dask.base import normalize_token from scipy import sparse from scipy.stats import rankdata @@ -20,7 +20,7 @@ # Copied from scikit-learn/sklearn/utils/fixes.py, can be removed once we drop # support for scikit-learn < 0.18.1 or numpy < 1.12.0. -if LooseVersion(np.__version__) < "1.12.0": +if packaging.version.parse(np.__version__) < packaging.version.parse("1.12.0"): class MaskedArray(np.ma.MaskedArray): # Before numpy 1.12, np.ma.MaskedArray object is not picklable diff --git a/dask_ml/model_selection/utils.py b/dask_ml/model_selection/utils.py index a9b2259f7..db5f9e0f5 100644 --- a/dask_ml/model_selection/utils.py +++ b/dask_ml/model_selection/utils.py @@ -1,12 +1,12 @@ import copy import warnings -from distutils.version import LooseVersion from itertools import compress import dask import dask.array as da import dask.dataframe as dd import numpy as np +import packaging.version import scipy.sparse as sp from dask.base import tokenize from dask.delayed import Delayed, delayed @@ -14,7 +14,7 @@ from ..utils import _num_samples -if LooseVersion(dask.__version__) > "0.15.4": +if packaging.version.parse(dask.__version__) > packaging.version.parse("0.15.4"): from dask.base import is_dask_collection else: from dask.base import Base diff --git a/dask_ml/preprocessing/data.py b/dask_ml/preprocessing/data.py index bf609a0fe..9cccd6895 100644 --- a/dask_ml/preprocessing/data.py +++ b/dask_ml/preprocessing/data.py @@ -4,12 +4,12 @@ import multiprocessing import numbers from collections import OrderedDict -from distutils.version import LooseVersion from typing import Any, List, Optional, Sequence, Union import dask.array as da import dask.dataframe as dd import numpy as np +import packaging.version import pandas as pd import sklearn.preprocessing from dask import compute @@ -26,8 +26,8 @@ from .._typing import ArrayLike, DataFrameType, NDArrayOrScalar, SeriesType from ..base import DaskMLBaseMixin -_PANDAS_VERSION = LooseVersion(pd.__version__) -_HAS_CTD = _PANDAS_VERSION >= "0.21.0" +_PANDAS_VERSION = packaging.version.parse(pd.__version__) +_HAS_CTD = _PANDAS_VERSION >= packaging.version.parse("0.21.0") BOUNDS_THRESHOLD = 1e-7 diff --git a/tests/linear_model/test_glm.py b/tests/linear_model/test_glm.py index b9453d20e..c99e05192 100644 --- a/tests/linear_model/test_glm.py +++ b/tests/linear_model/test_glm.py @@ -64,9 +64,11 @@ def test_fit(fit_intercept, solver): ) def test_fit_solver(solver): import dask_glm - from distutils.version import LooseVersion + import packaging.version - if LooseVersion(dask_glm.__version__) <= "0.2.0": + if packaging.version.parse(dask_glm.__version__) <= packaging.version.parse( + "0.2.0" + ): pytest.skip("FutureWarning for dask config.") X, y = make_classification(n_samples=100, n_features=5, chunks=50)