Skip to content

Commit

Permalink
Test that sync() propagates contextvars (dask#8354)
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky authored Nov 16, 2023
1 parent 6a379f8 commit 4fb2906
Showing 1 changed file with 36 additions and 18 deletions.
54 changes: 36 additions & 18 deletions distributed/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from array import array
from collections import deque
from concurrent.futures import Executor, Future, ThreadPoolExecutor
from contextvars import ContextVar
from time import sleep
from unittest import mock

Expand Down Expand Up @@ -97,35 +98,37 @@ async def inc(x):
assert end - start < 10


def test_sync(loop_in_thread):
async def f(x, y):
await asyncio.sleep(0.01)
return x, y

result = sync(loop_in_thread, f, 1, y=2)
assert result == (1, 2)


def test_sync_error(loop_in_thread):
loop = loop_in_thread
try:
result = sync(loop, throws, 1)
except Exception as exc:
f = exc
assert "hello" in str(exc)
tb = get_traceback()
L = traceback.format_tb(tb)
assert any("throws" in line for line in L)
with pytest.raises(RuntimeError, match="hello!") as exc:
sync(loop_in_thread, throws, 1)

L = traceback.format_tb(exc.value.__traceback__)
assert any("throws" in line for line in L)

def function1(x):
return function2(x)

def function2(x):
return throws(x)

try:
result = sync(loop, function1, 1)
except Exception as exc:
assert "hello" in str(exc)
tb = get_traceback()
L = traceback.format_tb(tb)
assert any("function1" in line for line in L)
assert any("function2" in line for line in L)
with pytest.raises(RuntimeError, match="hello!") as exc:
sync(loop_in_thread, function1, 1)

L = traceback.format_tb(exc.value.__traceback__)
assert any("function1" in line for line in L)
assert any("function2" in line for line in L)


def test_sync_timeout(loop_in_thread):
loop = loop_in_thread
with pytest.raises(TimeoutError):
sync(loop_in_thread, asyncio.sleep, 0.5, callback_timeout=0.05)

Expand All @@ -145,6 +148,21 @@ async def get_loop():
exc_info.match("IOLoop is clos(ed|ing)")


def test_sync_contextvars(loop_in_thread):
"""Test that sync() propagates contextvars - namely,
distributed.metrics.context_meter callbacks
"""
v = ContextVar("v", default=0)

async def f():
return v.get()

assert sync(loop_in_thread, f) == 0
tok = v.set(1)
assert sync(loop_in_thread, f) == 1
v.reset(tok)


def test_is_kernel():
pytest.importorskip("IPython")
assert is_kernel() is False
Expand Down

0 comments on commit 4fb2906

Please sign in to comment.