diff --git a/dask_cuda/tests/test_explicit_comms.py b/dask_cuda/tests/test_explicit_comms.py index 749c82b6..281a930e 100644 --- a/dask_cuda/tests/test_explicit_comms.py +++ b/dask_cuda/tests/test_explicit_comms.py @@ -8,6 +8,7 @@ import dask from dask import dataframe as dd from dask.dataframe.shuffle import partitioning_index +from dask.dataframe.utils import assert_eq from distributed import Client, get_worker from distributed.deploy.local import LocalCluster @@ -97,12 +98,8 @@ def check_partitions(df, npartitions): def _test_dataframe_shuffle(backend, protocol, n_workers): if backend == "cudf": cudf = pytest.importorskip("cudf") - from cudf.testing._utils import assert_eq - initialize(enable_tcp_over_ucx=True) else: - from dask.dataframe.utils import assert_eq - dask.config.update( dask.config.global_config, {"distributed.comm.ucx": get_ucx_config(enable_tcp_over_ucx=True),}, @@ -144,10 +141,7 @@ def _test_dataframe_shuffle(backend, protocol, n_workers): # Check the values of `ddf` (ignoring the row order) expected = df.sort_values("key") got = ddf.compute().sort_values("key") - if backend == "cudf": - assert_eq(got, expected) - else: - pd.testing.assert_frame_equal(got, expected) + assert_eq(got, expected) @pytest.mark.parametrize("nworkers", [1, 2, 3]) @@ -202,11 +196,9 @@ def test_dask_use_explicit_comms(): def _test_dataframe_shuffle_merge(backend, protocol, n_workers): if backend == "cudf": cudf = pytest.importorskip("cudf") - from cudf.testing._utils import assert_eq initialize(enable_tcp_over_ucx=True) else: - from dask.dataframe.utils import assert_eq dask.config.update( dask.config.global_config, @@ -243,10 +235,7 @@ def _test_dataframe_shuffle_merge(backend, protocol, n_workers): ) with dask.config.set(explicit_comms=True): got = ddf1.merge(ddf2, on="key").set_index("key").compute() - if backend == "cudf": - assert_eq(got, expected) - else: - pd.testing.assert_frame_equal(got, expected) + assert_eq(got, expected) @pytest.mark.parametrize("nworkers", [1, 2, 4]) @@ -265,7 +254,6 @@ def test_dataframe_shuffle_merge(backend, protocol, nworkers): def _test_jit_unspill(protocol): import cudf - from cudf.testing._utils import assert_eq with dask_cuda.LocalCUDACluster( protocol=protocol,