Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Specify HLG layer name during delayed object creation in fit #898

Merged
merged 3 commits into from
Jan 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions dask_ml/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
DASK_2_26_0 = DASK_VERSION >= packaging.version.parse("2.26.0")
DASK_2_28_0 = DASK_VERSION > packaging.version.parse("2.27.0")
DASK_2021_02_0 = DASK_VERSION >= packaging.version.parse("2021.02.0")
DASK_2022_01_0 = DASK_VERSION > packaging.version.parse("2021.12.0")
DISTRIBUTED_2_5_0 = DISTRIBUTED_VERSION > packaging.version.parse("2.5.0")
DISTRIBUTED_2_11_0 = DISTRIBUTED_VERSION > packaging.version.parse("2.10.0") # dev
DISTRIBUTED_2021_02_0 = DISTRIBUTED_VERSION >= packaging.version.parse("2021.02.0")
Expand Down
8 changes: 7 additions & 1 deletion dask_ml/_partial.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
from dask.highlevelgraph import HighLevelGraph
from toolz import partial

from ._compat import DASK_2022_01_0

logger = logging.getLogger(__name__)


Expand Down Expand Up @@ -125,7 +127,11 @@ def fit(
if y is not None:
dependencies.append(y)
new_dsk = HighLevelGraph.from_collections(name, dsk, dependencies=dependencies)
value = Delayed((name, nblocks - 1), new_dsk)

if DASK_2022_01_0:
value = Delayed((name, nblocks - 1), new_dsk, layer=name)
else:
value = Delayed((name, nblocks - 1), new_dsk)

if compute:
return value.compute()
Expand Down
4 changes: 2 additions & 2 deletions dask_ml/model_selection/methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,11 @@

import warnings
from collections import defaultdict
from distutils.version import LooseVersion
from threading import Lock
from timeit import default_timer

import numpy as np
import packaging.version
from dask.base import normalize_token
from scipy import sparse
from scipy.stats import rankdata
Expand All @@ -20,7 +20,7 @@

# Copied from scikit-learn/sklearn/utils/fixes.py, can be removed once we drop
# support for scikit-learn < 0.18.1 or numpy < 1.12.0.
if LooseVersion(np.__version__) < "1.12.0":
if packaging.version.parse(np.__version__) < packaging.version.parse("1.12.0"):
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These individual version checks could probably be moved into the `_compat` module, similar to how it's done in other parts of the codebase.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Alternatively, given that the minimum NumPy version supported by Dask today is 1.18, we could probably just bump the minimum supported NumPy version for Dask-ML too (though let's handle that in a separate PR).


class MaskedArray(np.ma.MaskedArray):
# Before numpy 1.12, np.ma.MaskedArray object is not picklable
Expand Down
4 changes: 2 additions & 2 deletions dask_ml/model_selection/utils.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
import copy
import warnings
from distutils.version import LooseVersion
from itertools import compress

import dask
import dask.array as da
import dask.dataframe as dd
import numpy as np
import packaging.version
import scipy.sparse as sp
from dask.base import tokenize
from dask.delayed import Delayed, delayed
from sklearn.utils.validation import _is_arraylike, indexable

from ..utils import _num_samples

if LooseVersion(dask.__version__) > "0.15.4":
if packaging.version.parse(dask.__version__) > packaging.version.parse("0.15.4"):
from dask.base import is_dask_collection
else:
from dask.base import Base
Expand Down
6 changes: 3 additions & 3 deletions dask_ml/preprocessing/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,12 @@
import multiprocessing
import numbers
from collections import OrderedDict
from distutils.version import LooseVersion
from typing import Any, List, Optional, Sequence, Union

import dask.array as da
import dask.dataframe as dd
import numpy as np
import packaging.version
import pandas as pd
import sklearn.preprocessing
from dask import compute
Expand All @@ -26,8 +26,8 @@
from .._typing import ArrayLike, DataFrameType, NDArrayOrScalar, SeriesType
from ..base import DaskMLBaseMixin

_PANDAS_VERSION = LooseVersion(pd.__version__)
_HAS_CTD = _PANDAS_VERSION >= "0.21.0"
_PANDAS_VERSION = packaging.version.parse(pd.__version__)
_HAS_CTD = _PANDAS_VERSION >= packaging.version.parse("0.21.0")
BOUNDS_THRESHOLD = 1e-7


Expand Down
6 changes: 4 additions & 2 deletions tests/linear_model/test_glm.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,11 @@ def test_fit(fit_intercept, solver):
)
def test_fit_solver(solver):
import dask_glm
from distutils.version import LooseVersion
import packaging.version

if LooseVersion(dask_glm.__version__) <= "0.2.0":
if packaging.version.parse(dask_glm.__version__) <= packaging.version.parse(
"0.2.0"
):
pytest.skip("FutureWarning for dask config.")

X, y = make_classification(n_samples=100, n_features=5, chunks=50)
Expand Down