Skip to content

Commit

Permalink
Change runtime.py to choose tasks with lowest (instead of highest) pr…
Browse files Browse the repository at this point in the history
…iority (#505)

Currently, the priority is set to the timestamp of the earliest undispatched task.
Choosing earliest tasks will reduce the maximum waiting time when queue is nonempty

Co-authored-by: Max Ryabinin <[email protected]>
Co-authored-by: Pavel Samygin <[email protected]>
(cherry picked from commit 6395e89)
  • Loading branch information
justheuristic authored and mryab committed Sep 13, 2022
1 parent a2f2407 commit 2ba000b
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions hivemind/moe/server/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,8 +143,8 @@ def iterate_minibatches_from_pools(self, timeout=None):
if self.SHUTDOWN_TRIGGER in ready_objects:
break # someone asked us to shutdown, break from the loop

logger.debug("Choosing the pool with highest priority")
pool = max(ready_objects, key=lambda pool: pool.priority)
logger.debug("Choosing the pool with first priority")
pool = min(ready_objects, key=lambda pool: pool.priority)

logger.debug(f"Loading batch from {pool.name}")
batch_index, batch_tensors = pool.load_batch_to_runtime(timeout, self.device)
Expand Down
4 changes: 2 additions & 2 deletions hivemind/moe/server/task_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class TaskPoolBase(mp.context.ForkProcess, metaclass=ABCMeta):
def __init__(self, process_func: callable, daemon=True, **kwargs):
super().__init__(daemon=daemon, **kwargs)
self.process_func = process_func
self._priority = mp.Value(ctypes.c_double, 1.0) # higher priority = the more urgent to process this pool
self._priority = mp.Value(ctypes.c_double, 1.0) # lower priority = the more urgent to process this pool

@abstractmethod
def run(self):
Expand Down Expand Up @@ -170,7 +170,7 @@ def _pool_input_loop(self, pending_batches: Dict[Any, List[Task]], *args, **kwar
for skip_i in range(prev_num_tasks):
finished_task_timestamp = (
self.undispatched_task_timestamps.get()
) # earlier timestamp = higher priority
) # earlier timestamp = smaller (better) priority, earlier processing
if skip_i == prev_num_tasks - 1:
self.priority = finished_task_timestamp

Expand Down

0 comments on commit 2ba000b

Please sign in to comment.