Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Tests: replacing the obsolete cudf.testing._utils.assert_eq calls #706

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 3 additions & 15 deletions dask_cuda/tests/test_explicit_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import dask
from dask import dataframe as dd
from dask.dataframe.shuffle import partitioning_index
from dask.dataframe.utils import assert_eq
from distributed import Client, get_worker
from distributed.deploy.local import LocalCluster

Expand Down Expand Up @@ -97,12 +98,8 @@ def check_partitions(df, npartitions):
def _test_dataframe_shuffle(backend, protocol, n_workers):
if backend == "cudf":
cudf = pytest.importorskip("cudf")
from cudf.testing._utils import assert_eq

initialize(enable_tcp_over_ucx=True)
else:
from dask.dataframe.utils import assert_eq

dask.config.update(
dask.config.global_config,
{"distributed.comm.ucx": get_ucx_config(enable_tcp_over_ucx=True),},
Expand Down Expand Up @@ -144,10 +141,7 @@ def _test_dataframe_shuffle(backend, protocol, n_workers):
# Check the values of `ddf` (ignoring the row order)
expected = df.sort_values("key")
got = ddf.compute().sort_values("key")
if backend == "cudf":
assert_eq(got, expected)
else:
pd.testing.assert_frame_equal(got, expected)
assert_eq(got, expected)


@pytest.mark.parametrize("nworkers", [1, 2, 3])
Expand Down Expand Up @@ -202,11 +196,9 @@ def test_dask_use_explicit_comms():
def _test_dataframe_shuffle_merge(backend, protocol, n_workers):
if backend == "cudf":
cudf = pytest.importorskip("cudf")
from cudf.testing._utils import assert_eq

initialize(enable_tcp_over_ucx=True)
else:
from dask.dataframe.utils import assert_eq

dask.config.update(
dask.config.global_config,
Expand Down Expand Up @@ -243,10 +235,7 @@ def _test_dataframe_shuffle_merge(backend, protocol, n_workers):
)
with dask.config.set(explicit_comms=True):
got = ddf1.merge(ddf2, on="key").set_index("key").compute()
if backend == "cudf":
assert_eq(got, expected)
else:
pd.testing.assert_frame_equal(got, expected)
assert_eq(got, expected)


@pytest.mark.parametrize("nworkers", [1, 2, 4])
Expand All @@ -265,7 +254,6 @@ def test_dataframe_shuffle_merge(backend, protocol, nworkers):

def _test_jit_unspill(protocol):
import cudf
from cudf.testing._utils import assert_eq

with dask_cuda.LocalCUDACluster(
protocol=protocol,
Expand Down