Skip to content

Commit

Permalink
test: replaced the obsolete cudf.testing._utils.assert_eq calls
Browse files Browse the repository at this point in the history
  • Loading branch information
madsbk committed Aug 18, 2021
1 parent 4c175f4 commit 336702c
Showing 1 changed file with 3 additions and 15 deletions.
18 changes: 3 additions & 15 deletions dask_cuda/tests/test_explicit_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import dask
from dask import dataframe as dd
from dask.dataframe.shuffle import partitioning_index
from dask.dataframe.utils import assert_eq
from distributed import Client, get_worker
from distributed.deploy.local import LocalCluster

Expand Down Expand Up @@ -97,12 +98,8 @@ def check_partitions(df, npartitions):
def _test_dataframe_shuffle(backend, protocol, n_workers):
if backend == "cudf":
cudf = pytest.importorskip("cudf")
from cudf.testing._utils import assert_eq

initialize(enable_tcp_over_ucx=True)
else:
from dask.dataframe.utils import assert_eq

dask.config.update(
dask.config.global_config,
{"distributed.comm.ucx": get_ucx_config(enable_tcp_over_ucx=True),},
Expand Down Expand Up @@ -144,10 +141,7 @@ def _test_dataframe_shuffle(backend, protocol, n_workers):
# Check the values of `ddf` (ignoring the row order)
expected = df.sort_values("key")
got = ddf.compute().sort_values("key")
if backend == "cudf":
assert_eq(got, expected)
else:
pd.testing.assert_frame_equal(got, expected)
assert_eq(got, expected)


@pytest.mark.parametrize("nworkers", [1, 2, 3])
Expand Down Expand Up @@ -202,11 +196,9 @@ def test_dask_use_explicit_comms():
def _test_dataframe_shuffle_merge(backend, protocol, n_workers):
if backend == "cudf":
cudf = pytest.importorskip("cudf")
from cudf.testing._utils import assert_eq

initialize(enable_tcp_over_ucx=True)
else:
from dask.dataframe.utils import assert_eq

dask.config.update(
dask.config.global_config,
Expand Down Expand Up @@ -243,10 +235,7 @@ def _test_dataframe_shuffle_merge(backend, protocol, n_workers):
)
with dask.config.set(explicit_comms=True):
got = ddf1.merge(ddf2, on="key").set_index("key").compute()
if backend == "cudf":
assert_eq(got, expected)
else:
pd.testing.assert_frame_equal(got, expected)
assert_eq(got, expected)


@pytest.mark.parametrize("nworkers", [1, 2, 4])
Expand All @@ -265,7 +254,6 @@ def test_dataframe_shuffle_merge(backend, protocol, nworkers):

def _test_jit_unspill(protocol):
import cudf
from cudf.testing._utils import assert_eq

with dask_cuda.LocalCUDACluster(
protocol=protocol,
Expand Down

0 comments on commit 336702c

Please sign in to comment.