Merge pull request #27 from mrocklin/scheduler
Refactor Scheduler to class
mrocklin committed Nov 28, 2015
2 parents 7415fbf + 83ebcf3 commit fad4ce3
Showing 8 changed files with 520 additions and 735 deletions.
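
In outline, this commit moves the scheduling state and coroutines that the Executor previously wired together itself (the queues, ncores, who_has, and the scheduler/worker/delete coroutines) into a dedicated Scheduler object that the Executor constructs and delegates to. Below is a minimal, standalone sketch of that delegation pattern; the names mirror the diff, but this is illustrative only, not the actual distributed API (the real Scheduler drives Tornado coroutines against the center rather than plain in-process queues).

from queue import Queue


class Scheduler:
    """Owns the scheduling state that previously lived on the Executor."""

    def __init__(self, center, delete_batch_time=1):
        self.center = center
        self.delete_batch_time = delete_batch_time
        self.scheduler_queue = Queue()   # instructions for the scheduler
        self.report_queue = Queue()      # messages reported back to the executor
        self.ncores = {}
        self.who_has = {}


class Executor:
    """Holds a Scheduler instead of managing scheduling state itself."""

    def __init__(self, center, delete_batch_time=1):
        self.center = center
        self.futures = {}
        self.scheduler = Scheduler(center, delete_batch_time=delete_batch_time)

    # Properties keep existing call sites such as
    # executor.scheduler_queue.put_nowait(...) working after the refactor.
    @property
    def scheduler_queue(self):
        return self.scheduler.scheduler_queue

    @property
    def report_queue(self):
        return self.scheduler.report_queue


if __name__ == "__main__":
    e = Executor(center="127.0.0.1:8787")
    e.scheduler_queue.put_nowait({"op": "close"})
    print(e.scheduler.scheduler_queue.get())  # -> {'op': 'close'}
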
1 change: 0 additions & 1 deletion distributed/__init__.py
@@ -3,7 +3,6 @@
from .client import scatter, gather, delete, clear, rpc
from .utils import sync
from .nanny import Nanny
from .dask import get
from .executor import Executor, wait, as_completed

__version__ = '1.4.0'
119 changes: 0 additions & 119 deletions distributed/dask.py

This file was deleted.
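(The scheduler, worker, and delete coroutines defined here appear to have been absorbed into the new Scheduler class; see the change in distributed/executor.py below, which replaces `from .dask import scheduler, worker, delete` with `from .scheduler import Scheduler`.)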

113 changes: 32 additions & 81 deletions distributed/executor.py
@@ -24,7 +24,7 @@
from .client import (WrappedKey, _gather, unpack_remotedata, pack_data,
scatter_to_workers)
from .core import read, write, connect, rpc, coerce_to_rpc
from .dask import scheduler, worker, delete
from .scheduler import Scheduler
from .sizeof import sizeof
from .utils import All, sync, funcname

@@ -162,22 +162,8 @@ def __init__(self, center, start=True, delete_batch_time=1, loop=None):
self.center = coerce_to_rpc(center)
self.futures = dict()
self.refcount = defaultdict(lambda: 0)
self.dask = dict()
self.restrictions = dict()
self.loop = loop or IOLoop()
self.report_queue = Queue()
self.scheduler_queue = Queue()
self._shutdown_event = Event()
self._delete_batch_time = delete_batch_time
self.ncores = dict()
self.nannies = dict()
self.who_has = defaultdict(set)
self.has_what = defaultdict(set)
self.waiting = {}
self.processing = {}
self.stacks = {}
self.held_data = set()
self.nbytes = dict()
self.scheduler = Scheduler(center, delete_batch_time=delete_batch_time)

if start:
self.start()
@@ -191,20 +177,25 @@ def start(self):
self._loop_thread.daemon = True
_global_executors.add(self)
self._loop_thread.start()
sync(self.loop, self._sync_center)
self.loop.add_callback(self._go)
while not len(self.stacks) == len(self.ncores):
time.sleep(0.01)
sync(self.loop, self._start)

@gen.coroutine
def _start(self):
yield self._sync_center()
self.loop.spawn_callback(self._go)
yield self.scheduler._sync_center()
self._scheduler_start_event = Event()
self.coroutines = [self.scheduler.start(), self.report()]
_global_executors.add(self)
while not len(self.stacks) == len(self.ncores):
yield gen.sleep(0.01)
yield self._scheduler_start_event.wait()
logger.debug("Started scheduling coroutines. Synchronized")

@property
def scheduler_queue(self):
return self.scheduler.scheduler_queue

@property
def report_queue(self):
return self.scheduler.report_queue

def __enter__(self):
if not self.loop._running:
self.start()
@@ -236,6 +227,8 @@ def report(self):
""" Listen to scheduler """
while True:
msg = yield self.report_queue.get()
if msg['op'] == 'start':
self._scheduler_start_event.set()
if msg['op'] == 'close':
break
if msg['op'] == 'key-in-memory':
@@ -254,15 +247,16 @@
self.futures[msg['key']]['event'].set()

@gen.coroutine
def _shutdown(self):
""" Send shutdown signal and wait until _go completes """
def _shutdown(self, fast=False):
""" Send shutdown signal and wait until scheduler completes """
self.loop.add_callback(self.report_queue.put_nowait,
{'op': 'close'})
self.loop.add_callback(self.scheduler_queue.put_nowait,
{'op': 'close'})
if self in _global_executors:
_global_executors.remove(self)
yield self._shutdown_event.wait()
if not fast:
yield self.coroutines

def shutdown(self):
""" Send shutdown signal and wait until scheduler terminates """
@@ -273,50 +267,6 @@ def shutdown(self):
if self in _global_executors:
_global_executors.remove(self)

@gen.coroutine
def _sync_center(self):
self.who_has.clear()
self.has_what.clear()
self.ncores.clear()
self.nannies.clear()

who_has, has_what, ncores, nannies = yield [self.center.who_has(),
self.center.has_what(),
self.center.ncores(),
self.center.nannies()]
logger.debug("Synchronize with center. Retrieve %d workers",
len(ncores))

self.who_has.update(who_has)
self.has_what.update(has_what)
self.ncores.update(ncores)
self.nannies.update(nannies)

@gen.coroutine
def _go(self):
""" Setup and run all other coroutines. Block until finished. """
worker_queues = {worker: Queue() for worker in self.ncores}
delete_queue = Queue()

for collection in [self.dask, self.nbytes, self.restrictions]:
collection.clear()

self.coroutines = ([
self.report(),
scheduler(self.scheduler_queue, self.report_queue, worker_queues,
delete_queue, who_has=self.who_has, has_what=self.has_what,
ncores=self.ncores, dsk=self.dask,
held_data=self.held_data, restrictions=self.restrictions,
waiting=self.waiting, stacks=self.stacks,
processing=self.processing, nbytes=self.nbytes),
delete(self.scheduler_queue, delete_queue,
self.center.ip, self.center.port, self._delete_batch_time)]
+ [worker(self.scheduler_queue, worker_queues[w], w, n)
for w, n in self.ncores.items()])

results = yield All(self.coroutines)
self._shutdown_event.set()

def submit(self, func, *args, **kwargs):
""" Submit a function application to the scheduler
@@ -498,11 +448,11 @@ def gather(self, futures):

@gen.coroutine
def _scatter(self, data, workers=None):
if not self.ncores:
if not self.scheduler.ncores:
raise ValueError("No workers yet found. "
"Try syncing with center.\n"
" e.sync_center()")
ncores = workers if workers is not None else self.ncores
ncores = workers if workers is not None else self.scheduler.ncores
remotes, who_has, nbytes = yield scatter_to_workers(
self.center, ncores, data)
if isinstance(remotes, list):
@@ -513,7 +463,7 @@ def _scatter(self, data, workers=None):
{'op': 'update-data',
'who-has': who_has,
'nbytes': nbytes})
while not all(k in self.who_has for k in who_has):
while not all(k in self.scheduler.who_has for k in who_has):
yield gen.sleep(0.001)
raise gen.Return(remotes)

@@ -593,31 +543,32 @@ def get(self, dsk, keys, **kwargs):
@gen.coroutine
def _restart(self):
logger.debug("Sending shutdown signal to workers")
for addr in self.nannies:
nannies = yield self.center.nannies()
for addr in nannies:
self.loop.add_callback(self.scheduler_queue.put_nowait,
{'op': 'worker-failed', 'worker': addr, 'heal': False})

logger.debug("Sending kill signal to nannies")
nannies = [rpc(ip=ip, port=n_port)
for (ip, w_port), n_port in self.nannies.items()]
for (ip, w_port), n_port in nannies.items()]
yield All([nanny.kill() for nanny in nannies])

while self.ncores:
while self.scheduler.ncores:
yield gen.sleep(0.01)

yield self._shutdown()
yield self._shutdown(fast=True)

events = [d['event'] for d in self.futures.values()]
self.futures.clear()
for e in events:
e.set()


yield All([nanny.instantiate(close=True) for nanny in nannies])

logger.info("Restarting executor")
self.report_queue = Queue()
self.scheduler_queue = Queue()
self.scheduler.report_queue = Queue()
self.scheduler.scheduler_queue = Queue()
self.scheduler.delete_queue = Queue()
yield self._start()

def restart(self):
[Diffs for the remaining changed files were not loaded in this view.]
