Faster beam search part2/2 #101

Closed
wants to merge 16 commits into from
171 changes: 123 additions & 48 deletions hivemind/dht/__init__.py
@@ -12,15 +12,17 @@
- [1] Maymounkov P., Mazieres D. (2002) Kademlia: A Peer-to-Peer Information System Based on the XOR Metric.
- [2] https://github.com/bmuller/kademlia , Brian, if you're reading this: THANK YOU! you're awesome :)
"""
import time
import asyncio
import ctypes
import multiprocessing as mp
import warnings
from collections import deque, OrderedDict
from concurrent.futures import ThreadPoolExecutor
from typing import List, Tuple, Optional, Sequence, OrderedDict as TOrderedDict, Union, Awaitable
from typing import List, Tuple, Optional, Sequence, OrderedDict as TOrderedDict, Union, Awaitable, Dict, Deque, Set

import uvloop
import numpy as np

from hivemind.client import RemoteExpert
from hivemind.dht.node import DHTNode, DHTID, DHTExpiration
@@ -66,6 +68,7 @@ class DHT(mp.Process):
indices from the previous step. Finally, MoE will use DHT.get_experts(uids: List[str]) to search for specific experts.
This beam search explores one additional dimension per step and finds k best experts from across the DHT
in O(k / s * log(N)) average time where s is grid sparsity rate and N is the total number of experts.
TODO(jheuristic) replace _first_k_active with beam search description!
"""

UID_DELIMITER = '.' # when declaring experts, DHT stores all prefixes of that expert's uid, split over this delimiter
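# e.g. a hypothetical expert "expert.3.12" would be announced under the prefixes "expert",
# "expert.3" and "expert.3.12", letting beam search descend one grid dimension at a time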
@@ -182,60 +185,132 @@ async def _declare_experts(self, node: DHTNode, uids: List[str], endpoint: Endpo
if future is not None:
future.set_result([store_ok[key] for key in data_to_store.keys()])

def first_k_active(
self, uid_prefixes: List[str], k: int, max_prefetch: int = 1, chunk_size: Optional[int] = None,
return_future=False) -> Union[TOrderedDict[str, RemoteExpert], Awaitable[TOrderedDict[str, RemoteExpert]]]:
def find_best_experts(
self, prefix: str, grid_scores: List[np.ndarray], k_best: int, *, time_budget: float = float('inf'),
return_future=False, **kwargs) -> Union[List[RemoteExpert], MPFuture]:
"""
Find and return up to k_best active experts with the highest scores, using both the local cache and the DHT

:param prefix: common prefix for all expert uids in grid
:param grid_scores: scores predicted for each dimension in the grid
:type grid_scores: list of 1d arrays, one per grid dimension, each of shape grid_size[i]
:param k_best: how many best experts should beam search return
:param time_budget: how much time beam search can spend on queries to other peers (default = unlimited)
After time_budget is exhausted, beam search stops querying peers and falls back on the local cache.
Note that queries that exceed the budget are still performed in the background and cached
for subsequent iterations as long as DHTNode.cache_locally is True.
:param return_future: if set to True, returns MPFuture that can be awaited to get the actual result
:param kwargs: extra keyword parameters forwarded to first_k_active
:returns: a list that contains *up to* k_best RemoteExpert instances
"""
grid_scores = list(map(np.asanyarray, grid_scores))
assert all(dim_scores.ndim == 1 for dim_scores in grid_scores)

future, _future = MPFuture.make_pair()
self.pipe.send(('_find_best_experts', [], dict(prefix=prefix, grid_scores=grid_scores, k_best=k_best,
time_budget=time_budget, future=_future, **kwargs)))
return future if return_future else future.result()
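# --- Illustrative usage sketch (not part of this diff) ---
# The constructor arguments, the "expert" prefix and the 4 x 8 grid below are
# assumptions made for demonstration only.
import numpy as np
dht = DHT(start=True)  # an already-running, bootstrapped DHT process
grid_scores = [np.random.randn(4), np.random.randn(8)]  # one 1d score array per grid dimension
best = dht.find_best_experts("expert", grid_scores, k_best=2, time_budget=1.0)
# `best` holds *up to* k_best RemoteExpert instances, best combined score first
# --- end of sketch ---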

async def _find_best_experts(
self, node: DHTNode, prefix: str, grid_scores: List[np.ndarray], k_best: int, time_budget: float,
prefetch_rate: int = 1, future: MPFuture = None, **kwargs) -> List[RemoteExpert]:
deadline_time = time.perf_counter() + time_budget
beam_experts: List[RemoteExpert] = []
beam: List[str] = [prefix]
beam_scores = np.zeros(1)

for dim_index, dim_scores in enumerate(grid_scores):
# create all possible successors from current beam and sort them by total score
expanded_scores = beam_scores[:, None] + dim_scores[None, :]
sorted_indices_ravel = (-expanded_scores).ravel().argsort()
sorted_indices = np.stack(np.unravel_index(sorted_indices_ravel, expanded_scores.shape), axis=-1)
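# each row of sorted_indices is a (beam_index, grid_column) pair, ordered from highest to lowest combined score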

sorted_prefix_uids = [f"{beam[row]}{self.UID_DELIMITER}{col:d}" for row, col in sorted_indices]
candidate_to_sorted_indices: Dict[str, Sequence[int]] = dict(zip(sorted_prefix_uids, sorted_indices))

# select k best candidates according to scores but only those that are still active
best_alive_prefixes: TOrderedDict[str, RemoteExpert] = await self._first_k_active(
node, sorted_prefix_uids, k=k_best, prefetch_rate=prefetch_rate,
time_budget=deadline_time - time.perf_counter(), **kwargs)

if not best_alive_prefixes:
logger.warning(f"Grid is empty: found neither of {sorted_prefix_uids}")
break

beam = list(best_alive_prefixes.keys())
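# gather each surviving candidate's combined score by mapping its uid back to the (beam_index, grid_column) pair it came from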
beam_scores = expanded_scores[tuple(zip(*map(candidate_to_sorted_indices.get, beam)))]
beam_experts = list(best_alive_prefixes.values())

future.set_result(beam_experts)
return beam_experts

def first_k_active(self, uid_prefixes: List[str], k: int, prefetch_rate: int = 1, time_budget: float = float('inf'),
return_future: bool = False, **kwargs) -> Union[TOrderedDict[str, RemoteExpert], Awaitable]:
"""
Find k prefixes with active experts; may return fewer if there aren't enough; used for DMoE beam search

:param uid_prefixes: a list of uid prefixes ordered from highest to lowest priority
:param uid_prefixes: a list of (unique) uid prefixes ordered from highest to lowest priority
:param k: return at most *this many* active prefixes
:param max_prefetch: pre-dispatch up to *this many* tasks (each for chunk_size experts)
:param chunk_size: dispatch this many requests in one task
:param return_future: if False (default), return when experts are returned. Otherwise return MPFuture.
:param prefetch_rate: dispatch up to this many GET requests in parallel for each uid_prefix
:param time_budget: if the procedure goes on for this many seconds, stop and return the best found experts
:param return_future: if True, returns MPFuture that can be awaited to get result, default = return result
:returns: an ordered dict {uid_prefix -> RemoteExpert} mapping at most :k: prefixes to matching experts
The keys in the returned dict are ordered the same way as in uid_prefixes.
"""
assert not isinstance(uid_prefixes, str), "please provide a list/tuple of prefixes as the first argument"
logger.warning("DHT.first_k_active is deprecated and will be removed in v0.9")
future, _future = MPFuture.make_pair()
self.pipe.send(('_first_k_active', [],
dict(uid_prefixes=uid_prefixes, k=k, max_prefetch=max_prefetch,
chunk_size=chunk_size or k, future=_future)))
self.pipe.send(('_first_k_active', [], dict(
uid_prefixes=uid_prefixes, k=k, future=_future, prefetch_rate=prefetch_rate or k,
time_budget=time_budget, **kwargs)))
return future if return_future else future.result()
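# --- Illustrative usage sketch (not part of this diff) ---
# The uid prefixes are hypothetical and `dht` is assumed to be a running DHT instance;
# note that first_k_active is deprecated (see the warning above).
candidate_prefixes = ["expert.0", "expert.1", "expert.2", "expert.3"]
active = dht.first_k_active(candidate_prefixes, k=2, time_budget=0.5)
for uid_prefix, expert in active.items():  # at most k entries, highest-priority first
    print(uid_prefix, expert)
# --- end of sketch ---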

async def _first_k_active(
self, node: DHTNode, uid_prefixes: List[str], k: int, max_prefetch: int, chunk_size: int, future: MPFuture):
num_workers_per_chunk = min(chunk_size, self.max_workers or chunk_size)
total_chunks = (len(uid_prefixes) - 1) // chunk_size + 1
found: List[Tuple[str, RemoteExpert]] = []

pending_tasks = deque(
asyncio.create_task(node.get_many(uid_prefixes[chunk_i * chunk_size: (chunk_i + 1) * chunk_size],
num_workers=num_workers_per_chunk))
for chunk_i in range(min(max_prefetch + 1, total_chunks))
) # pre-dispatch first task and up to max_prefetch additional tasks

for chunk_i in range(total_chunks):
# parse task results in chronological order, launch additional tasks on demand
response = await pending_tasks.popleft()
for uid_prefix in uid_prefixes[chunk_i * chunk_size: (chunk_i + 1) * chunk_size]:
maybe_expert_data, maybe_expiration_time = response[uid_prefix]
if maybe_expiration_time is not None: # found active peer
found.append((uid_prefix, RemoteExpert(**maybe_expert_data)))
# if we found enough active experts, finish immediately
if len(found) >= k:
break
if len(found) >= k:
break

pre_dispatch_chunk_i = chunk_i + len(pending_tasks) + 1
if pre_dispatch_chunk_i < total_chunks:
pending_tasks.append(asyncio.create_task(node.get_many(
uid_prefixes[pre_dispatch_chunk_i * chunk_size: (pre_dispatch_chunk_i + 1) * chunk_size],
num_workers=num_workers_per_chunk)))

for task in pending_tasks:
task.cancel()

# return k active prefixes or as many as we could find
future.set_result(OrderedDict(found))
async def _first_k_active(self, node: DHTNode, uid_prefixes: List[str], k: int, prefetch_rate: int,
time_budget: float, future=None, **kwargs) -> TOrderedDict[str, RemoteExpert]:
deadline_time = time.perf_counter() + time_budget
unattempted_uids_reversed = list(reversed(uid_prefixes))
pending_tasks: Deque[Tuple[str, asyncio.Task]] = deque()
found: TOrderedDict[str, RemoteExpert] = OrderedDict()

while len(found) < k and (unattempted_uids_reversed or pending_tasks):
# dispatch additional tasks
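# (keep up to prefetch_rate pending GET requests per expert still missing from `found`)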
while unattempted_uids_reversed and len(pending_tasks) < (prefetch_rate * (k - len(found))):
uid_prefix = unattempted_uids_reversed.pop()
pending_tasks.append((uid_prefix, asyncio.create_task(node.get(uid_prefix, **kwargs))))

uid_prefix, getter_task = pending_tasks.popleft()
try:
maybe_expert_data, maybe_expiration_time = await asyncio.wait_for(
getter_task, timeout=deadline_time - time.perf_counter())
if maybe_expiration_time is not None: # found active expert
found[uid_prefix] = RemoteExpert(**maybe_expert_data)

except asyncio.TimeoutError:
# pick up pending tasks that have already returned a result
while len(found) < k and pending_tasks:
uid_prefix, getter_task = pending_tasks.popleft()
if getter_task.done():
maybe_expert_data, maybe_expiration_time = getter_task.result()
if maybe_expiration_time is not None: # found active expert
found[uid_prefix] = RemoteExpert(**maybe_expert_data)

# check for remaining uids in node's local storage/cache
while len(found) < k and unattempted_uids_reversed:
uid_prefix = unattempted_uids_reversed.pop()
maybe_expert_data, maybe_expiration_time = node.get_locally(uid_prefix)
if maybe_expiration_time is not None: # found active expert
found[uid_prefix] = RemoteExpert(**maybe_expert_data)

# cancel remaining tasks
for uid_prefix, getter_task in pending_tasks:
getter_task.cancel()

break # we ran out of time, return what we have

except asyncio.CancelledError:
for key, getter_task in pending_tasks:
getter_task.cancel()
raise

if future:
future.set_result(OrderedDict(found))
return OrderedDict(found)
17 changes: 17 additions & 0 deletions hivemind/dht/node.py
@@ -211,6 +211,14 @@ async def get_neighbors(peer: DHTID, queries: Collection[DHTID]) -> Dict[DHTID,
nearest_nodes_with_endpoints[query] = {node: node_to_endpoint[node] for node in nearest_nodes[:k_nearest]}
return nearest_nodes_with_endpoints

def store_locally(self, key: DHTKey, value: DHTValue, expiration_time: DHTExpiration, in_cache=False) -> bool:
"""
(synchronous) Add key->value pair to this node's local storage or cache until expiration_time
Note: this does NOT guarantee that the key will be available to other peers. Use DHTNode.store for that.
"""
chosen_storage = self.protocol.cache if in_cache else self.protocol.storage
return chosen_storage.store(DHTID.generate(key), self.serializer.dumps(value), expiration_time)

async def store(self, key: DHTKey, value: DHTValue, expiration_time: DHTExpiration, **kwargs) -> bool:
"""
Find num_replicas best nodes to store (key, value) and store it there at least until expiration time.
@@ -314,6 +322,15 @@ async def on_found(key_id: DHTID, nearest_nodes: List[DHTID], visited_nodes: Set
store_task.cancel()
raise e

def get_locally(self, key: DHTKey) -> Tuple[Optional[DHTValue], Optional[DHTExpiration]]:
""" (synchronous) Search for key in this node's local storage and cache, return latest (None if not found) """
key_id = DHTID.generate(source=key)
maybe_value_bytes, maybe_expiration = self.protocol.storage.get(key_id)
maybe_cached_value, maybe_cache_expiration = self.protocol.cache.get(key_id)
if (maybe_cache_expiration or -float('inf')) > (maybe_expiration or -float('inf')):
maybe_value_bytes, maybe_expiration = maybe_cached_value, maybe_cache_expiration
return maybe_value_bytes, maybe_expiration
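# --- Illustrative usage sketch (not part of this diff) ---
# `node` is assumed to be an existing DHTNode and get_dht_time to be in scope (it is
# used elsewhere in hivemind.dht); the key and value are hypothetical.
node.store_locally("my_key", b"some_value", expiration_time=get_dht_time() + 60)
maybe_value, maybe_expiration = node.get_locally("my_key")  # prefers the later-expiring of storage/cache entries
assert maybe_expiration is not None  # the entry we just stored is visible locally
# --- end of sketch ---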

async def get(self, key: DHTKey, latest=False, **kwargs) -> Tuple[Optional[DHTValue], Optional[DHTExpiration]]:
"""
Search for a key across DHT and return either first or latest entry.
2 changes: 1 addition & 1 deletion hivemind/dht/protocol.py
@@ -282,7 +282,7 @@ def _remove_outdated(self):
def store(self, key: DHTID, value: BinaryDHTValue, expiration_time: DHTExpiration) -> bool:
"""
Store a (key, value) pair locally at least until expiration_time. See class docstring for details.
:returns: True if new value was stored, False it was rejected (current value is newer)
:returns: True if the new value was stored, False if it was rejected (e.g. if there is a newer value for that key)
"""
if expiration_time < get_dht_time() and not self.frozen:
return False