Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make proxy tests with LocalCUDACluster asynchronous #1084

Merged
merged 5 commits into from
Jan 16, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 31 additions & 20 deletions dask_cuda/tests/test_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,10 @@
from dask.sizeof import sizeof
from distributed import Client
from distributed.protocol.serialize import deserialize, serialize
from distributed.utils_test import gen_test

import dask_cuda
from dask_cuda import proxy_object
from dask_cuda import LocalCUDACluster, proxy_object
from dask_cuda.disk_io import SpillToDiskFile
from dask_cuda.proxify_device_objects import proxify_device_objects
from dask_cuda.proxify_host_file import ProxifyHostFile
Expand Down Expand Up @@ -282,7 +283,8 @@ def test_fixed_attribute_name():


@pytest.mark.parametrize("jit_unspill", [True, False])
def test_spilling_local_cuda_cluster(jit_unspill):
@gen_test(timeout=20)
async def test_spilling_local_cuda_cluster(jit_unspill):
"""Testing spilling of a proxied cudf dataframe in a local cuda cluster"""
cudf = pytest.importorskip("cudf")
dask_cudf = pytest.importorskip("dask_cudf")
Expand All @@ -299,14 +301,17 @@ def task(x):
return x

# Notice, setting `device_memory_limit=1B` to trigger spilling
with dask_cuda.LocalCUDACluster(
n_workers=1, device_memory_limit="1B", jit_unspill=jit_unspill
async with LocalCUDACluster(
n_workers=1,
device_memory_limit="1B",
jit_unspill=jit_unspill,
asynchronous=True,
) as cluster:
with Client(cluster):
async with Client(cluster, asynchronous=True) as client:
df = cudf.DataFrame({"a": range(10)})
ddf = dask_cudf.from_cudf(df, npartitions=1)
ddf = ddf.map_partitions(task, meta=df.head())
got = ddf.compute()
got = await client.compute(ddf)
if isinstance(got, pandas.Series):
pytest.xfail(
"BUG fixed by <https://github.com/rapidsai/dask-cuda/pull/451>"
Expand Down Expand Up @@ -395,7 +400,8 @@ def _pxy_deserialize(self):

@pytest.mark.parametrize("send_serializers", [None, ("dask", "pickle"), ("cuda",)])
@pytest.mark.parametrize("protocol", ["tcp", "ucx"])
def test_communicating_proxy_objects(protocol, send_serializers):
@gen_test(timeout=20)
async def test_communicating_proxy_objects(protocol, send_serializers):
"""Testing serialization of cuDF dataframe when communicating"""
cudf = pytest.importorskip("cudf")

Expand All @@ -413,10 +419,13 @@ def task(x):
else:
assert serializers_used == "dask"

with dask_cuda.LocalCUDACluster(
n_workers=1, protocol=protocol, enable_tcp_over_ucx=protocol == "ucx"
async with dask_cuda.LocalCUDACluster(
n_workers=1,
protocol=protocol,
enable_tcp_over_ucx=protocol == "ucx",
asynchronous=True,
) as cluster:
with Client(cluster) as client:
async with Client(cluster, asynchronous=True) as client:
df = cudf.DataFrame({"a": range(10)})
df = proxy_object.asproxy(
df, serializers=send_serializers, subclass=_PxyObjTest
Expand All @@ -429,14 +438,14 @@ def task(x):
df._pxy_get().assert_on_deserializing = False
else:
df._pxy_get().assert_on_deserializing = True
df = client.scatter(df)
client.submit(task, df).result()
client.shutdown() # Avoids a UCX shutdown error
df = await client.scatter(df)
await client.submit(task, df)


@pytest.mark.parametrize("protocol", ["tcp", "ucx"])
@pytest.mark.parametrize("shared_fs", [True, False])
def test_communicating_disk_objects(protocol, shared_fs):
@gen_test(timeout=20)
async def test_communicating_disk_objects(protocol, shared_fs):
"""Testing disk serialization of cuDF dataframe when communicating"""
cudf = pytest.importorskip("cudf")
ProxifyHostFile._spill_to_disk.shared_filesystem = shared_fs
Expand All @@ -450,16 +459,18 @@ def task(x):
else:
assert serializer_used == "dask"

with dask_cuda.LocalCUDACluster(
n_workers=1, protocol=protocol, enable_tcp_over_ucx=protocol == "ucx"
async with dask_cuda.LocalCUDACluster(
n_workers=1,
protocol=protocol,
enable_tcp_over_ucx=protocol == "ucx",
asynchronous=True,
) as cluster:
with Client(cluster) as client:
async with Client(cluster, asynchronous=True) as client:
df = cudf.DataFrame({"a": range(10)})
df = proxy_object.asproxy(df, serializers=("disk",), subclass=_PxyObjTest)
df._pxy_get().assert_on_deserializing = False
df = client.scatter(df)
client.submit(task, df).result()
client.shutdown() # Avoids a UCX shutdown error
df = await client.scatter(df)
await client.submit(task, df)


@pytest.mark.parametrize("array_module", ["numpy", "cupy"])
Expand Down