Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generalize UCX errors on connect() and correct pytest fixtures #6434

Merged
merged 3 commits into from
May 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion distributed/comm/tests/test_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -563,7 +563,7 @@ async def client_communicate(key, delay=0):

@pytest.mark.gpu
@gen_test()
async def test_ucx_client_server():
async def test_ucx_client_server(ucx_loop):
pytest.importorskip("distributed.comm.ucx")
ucp = pytest.importorskip("ucp")

Expand Down
71 changes: 30 additions & 41 deletions distributed/comm/tests/test_ucx.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,25 +22,7 @@
HOST = "127.0.0.1"


def handle_exception(loop, context):
msg = context.get("exception", context["message"])
print(msg)


# Let's make sure that UCX gets time to cancel
# progress tasks before closing the event loop.
@pytest.fixture()
def event_loop(scope="function"):
loop = asyncio.new_event_loop()
loop.set_exception_handler(handle_exception)
ucp.reset()
yield loop
ucp.reset()
loop.run_until_complete(asyncio.sleep(0))
loop.close()


def test_registered():
def test_registered(ucx_loop):
assert "ucx" in backends
backend = get_backend("ucx")
assert isinstance(backend, ucx.UCXBackend)
Expand All @@ -62,7 +44,7 @@ async def handle_comm(comm):


@gen_test()
async def test_ping_pong():
async def test_ping_pong(ucx_loop):
com, serv_com = await get_comm_pair()
msg = {"op": "ping"}
await com.write(msg)
Expand All @@ -80,7 +62,7 @@ async def test_ping_pong():


@gen_test()
async def test_comm_objs():
async def test_comm_objs(ucx_loop):
comm, serv_comm = await get_comm_pair()

scheme, loc = parse_address(comm.peer_address)
Expand All @@ -93,7 +75,7 @@ async def test_comm_objs():


@gen_test()
async def test_ucx_specific():
async def test_ucx_specific(ucx_loop):
"""
Test concrete UCX API.
"""
Expand Down Expand Up @@ -147,7 +129,7 @@ async def client_communicate(key, delay=0):


@gen_test()
async def test_ping_pong_data():
async def test_ping_pong_data(ucx_loop):
np = pytest.importorskip("numpy")

data = np.ones((10, 10))
Expand All @@ -170,7 +152,7 @@ async def test_ping_pong_data():


@gen_test()
async def test_ucx_deserialize():
async def test_ucx_deserialize(ucx_loop):
# Note we see this error on some systems with this test:
# `socket.gaierror: [Errno -5] No address associated with hostname`
# This may be due to a system configuration issue.
Expand All @@ -196,7 +178,7 @@ async def test_ucx_deserialize():
],
)
@gen_test()
async def test_ping_pong_cudf(g):
async def test_ping_pong_cudf(ucx_loop, g):
# if this test appears after cupy an import error arises
# *** ImportError: /usr/lib/x86_64-linux-gnu/libstdc++.so.6: version `CXXABI_1.3.11'
# not found (required by python3.7/site-packages/pyarrow/../../../libarrow.so.12)
Expand All @@ -221,7 +203,7 @@ async def test_ping_pong_cudf(g):

@pytest.mark.parametrize("shape", [(100,), (10, 10), (4947,)])
@gen_test()
async def test_ping_pong_cupy(shape):
async def test_ping_pong_cupy(ucx_loop, shape):
cupy = pytest.importorskip("cupy")
com, serv_com = await get_comm_pair()

Expand All @@ -240,7 +222,7 @@ async def test_ping_pong_cupy(shape):
@pytest.mark.slow
@pytest.mark.parametrize("n", [int(1e9), int(2.5e9)])
@gen_test()
async def test_large_cupy(n, cleanup):
async def test_large_cupy(ucx_loop, n, cleanup):
cupy = pytest.importorskip("cupy")
com, serv_com = await get_comm_pair()

Expand All @@ -257,7 +239,7 @@ async def test_large_cupy(n, cleanup):


@gen_test()
async def test_ping_pong_numba():
async def test_ping_pong_numba(ucx_loop):
np = pytest.importorskip("numpy")
numba = pytest.importorskip("numba")
import numba.cuda
Expand All @@ -276,7 +258,7 @@ async def test_ping_pong_numba():

@pytest.mark.parametrize("processes", [True, False])
@gen_test()
async def test_ucx_localcluster(processes, cleanup):
async def test_ucx_localcluster(ucx_loop, processes, cleanup):
async with LocalCluster(
protocol="ucx",
host=HOST,
Expand All @@ -297,7 +279,9 @@ async def test_ucx_localcluster(processes, cleanup):

@pytest.mark.slow
@gen_test(timeout=60)
async def test_stress():
async def test_stress(
ucx_loop,
):
da = pytest.importorskip("dask.array")

chunksize = "10 MB"
Expand All @@ -322,15 +306,19 @@ async def test_stress():


@gen_test()
async def test_simple():
async def test_simple(
ucx_loop,
):
async with LocalCluster(protocol="ucx", asynchronous=True) as cluster:
async with Client(cluster, asynchronous=True) as client:
assert cluster.scheduler_address.startswith("ucx://")
assert await client.submit(lambda x: x + 1, 10) == 11


@gen_test()
async def test_cuda_context():
async def test_cuda_context(
ucx_loop,
):
with dask.config.set({"distributed.comm.ucx.create-cuda-context": True}):
async with LocalCluster(
protocol="ucx", n_workers=1, asynchronous=True
Expand All @@ -344,7 +332,9 @@ async def test_cuda_context():


@gen_test()
async def test_transpose():
async def test_transpose(
ucx_loop,
):
da = pytest.importorskip("dask.array")

async with LocalCluster(protocol="ucx", asynchronous=True) as cluster:
Expand All @@ -358,7 +348,7 @@ async def test_transpose():

@pytest.mark.parametrize("port", [0, 1234])
@gen_test()
async def test_ucx_protocol(cleanup, port):
async def test_ucx_protocol(ucx_loop, cleanup, port):
async with Scheduler(protocol="ucx", port=port, dashboard_address=":0") as s:
assert s.address.startswith("ucx://")

Expand All @@ -367,10 +357,9 @@ async def test_ucx_protocol(cleanup, port):
not hasattr(ucp.exceptions, "UCXUnreachable"),
reason="Requires UCX-Py support for UCXUnreachable exception",
)
def test_ucx_unreachable():
if ucp.get_ucx_version() > (1, 12, 0):
with pytest.raises(OSError, match="Timed out trying to connect to"):
Client("ucx://255.255.255.255:12345", timeout=1)
else:
with pytest.raises(ucp.exceptions.UCXError, match="Destination is unreachable"):
Client("ucx://255.255.255.255:12345", timeout=1)
@gen_test()
async def test_ucx_unreachable(
ucx_loop,
):
with pytest.raises(OSError, match="Timed out trying to connect to"):
await Client("ucx://255.255.255.255:12345", timeout=1, asynchronous=True)
4 changes: 2 additions & 2 deletions distributed/comm/tests/test_ucx_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@


@gen_test()
async def test_ucx_config(cleanup):
async def test_ucx_config(ucx_loop, cleanup):
ucx = {
"nvlink": True,
"infiniband": True,
Expand Down Expand Up @@ -79,7 +79,7 @@ async def test_ucx_config(cleanup):
reruns=10,
reruns_delay=5,
)
def test_ucx_config_w_env_var(cleanup, loop):
def test_ucx_config_w_env_var(ucx_loop, cleanup, loop):
env = os.environ.copy()
env["DASK_RMM__POOL_SIZE"] = "1000.00 MB"

Expand Down
6 changes: 1 addition & 5 deletions distributed/comm/ucx.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,11 +397,7 @@ async def connect(self, address: str, deserialize=True, **connection_args) -> UC
init_once()
try:
ep = await ucp.create_endpoint(ip, port)
except (ucp.exceptions.UCXCloseError, ucp.exceptions.UCXCanceled,) + (
getattr(ucp.exceptions, "UCXConnectionReset", ()),
getattr(ucp.exceptions, "UCXNotConnected", ()),
getattr(ucp.exceptions, "UCXUnreachable", ()),
): # type: ignore
except ucp.exceptions.UCXBaseException:
raise CommClosedError("Connection closed before handshake completed")
return self.comm_class(
ep,
Expand Down
2 changes: 1 addition & 1 deletion distributed/tests/test_nanny.py
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ def raise_err():

@pytest.mark.parametrize("protocol", ["tcp", "ucx"])
@gen_test()
async def test_nanny_closed_by_keyboard_interrupt(protocol):
async def test_nanny_closed_by_keyboard_interrupt(ucx_loop, protocol):
if protocol == "ucx": # Skip if UCX isn't available
pytest.importorskip("ucp")

Expand Down
2 changes: 1 addition & 1 deletion distributed/tests/test_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -1384,7 +1384,7 @@ async def test_interface_async(Worker):
@pytest.mark.gpu
@pytest.mark.parametrize("Worker", [Worker, Nanny])
@gen_test()
async def test_protocol_from_scheduler_address(Worker):
async def test_protocol_from_scheduler_address(ucx_loop, Worker):
pytest.importorskip("ucp")

async with Scheduler(protocol="ucx", dashboard_address=":0") as s:
Expand Down
36 changes: 36 additions & 0 deletions distributed/utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -2153,3 +2153,39 @@ def raises_with_cause(
assert re.search(
match_cause, str(exc.__cause__)
), f"Pattern ``{match_cause}`` not found in ``{exc.__cause__}``"


def ucx_exception_handler(loop, context):
"""UCX exception handler for `ucx_loop` during test.

Prints the exception and its message.

Parameters
----------
loop: object
Reference to the running event loop
context: dict
Dictionary containing exception details.
"""
msg = context.get("exception", context["message"])
print(msg)


# Let's make sure that UCX gets time to cancel
# progress tasks before closing the event loop.
@pytest.fixture(scope="function")
def ucx_loop():
"""Allows UCX to cancel progress tasks before closing event loop.

When UCX tasks are not completed in time (e.g., by unexpected Endpoint
closure), clean up tasks before closing the event loop to prevent unwanted
errors from being raised.
"""
ucp = pytest.importorskip("ucp")

loop = asyncio.new_event_loop()
loop.set_exception_handler(ucx_exception_handler)
ucp.reset()
yield loop
ucp.reset()
loop.close()