Skip to content

Commit

Permalink
revisit test
Browse files Browse the repository at this point in the history
  • Loading branch information
crusaderky committed Apr 27, 2022
1 parent 4f2558d commit 4f7b8e9
Showing 1 changed file with 21 additions and 8 deletions.
29 changes: 21 additions & 8 deletions distributed/tests/test_worker_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import asyncio
import logging
import threading
from collections import Counter, UserDict
from time import sleep

Expand Down Expand Up @@ -649,10 +650,11 @@ def leak():
},
)
async def test_pause_while_spilling(c, s, a):
N = 50
N_PAUSE = 3
N_TOTAL = 5

def get_process_memory():
if len(a.data) < N:
if len(a.data) < N_PAUSE:
# Don't trigger spilling until after all tasks have completed
return 0
elif a.data.fast and not a.data.slow:
Expand All @@ -665,20 +667,31 @@ def get_process_memory():
a.monitor.get_process_memory = get_process_memory

class SlowSpill:
def __init__(self, _, paused: bool = False):
self.paused = paused
def __init__(self, _):
# Can't pickle a Semaphore, so instead of a default value, we create it
# here. Don't worry about race conditions; the worker is single-threaded.
if not hasattr(type(self), "sem"):
type(self).sem = threading.Semaphore(N_PAUSE)
# Block if there are N_PAUSE tasks in a.data.fast
self.sem.acquire()

def __reduce__(self):
paused = distributed.get_worker().status == Status.paused
if not paused:
sleep(0.1)
return SlowSpill, (None, paused)
self.sem.release()
return bool, (paused,)

futs = c.map(SlowSpill, range(N))
while len(a.data.slow) < N:
futs = c.map(SlowSpill, range(N_TOTAL))
while len(a.data.slow) < N_PAUSE + 1:
await asyncio.sleep(0.01)

assert a.status == Status.paused
assert any(sp.paused for sp in a.data.values())
# Worker should have become paused after the first `SlowSpill` was evicted, because
# the spill to disk took longer than the memory monitor interval.
assert len(a.data.fast) == 0
assert len(a.data.slow) == N_PAUSE + 1
assert sum(paused is True for paused in a.data.slow.values()) == N_PAUSE


@pytest.mark.slow
Expand Down

0 comments on commit 4f7b8e9

Please sign in to comment.