From 7b40e64a67f0a45e3a9d4aacb92974bab8bf9eee Mon Sep 17 00:00:00 2001 From: justheuristic Date: Tue, 15 Sep 2020 13:05:48 +0300 Subject: [PATCH 01/16] get key locally --- hivemind/dht/node.py | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/hivemind/dht/node.py b/hivemind/dht/node.py index 47218de00..f985f62c1 100644 --- a/hivemind/dht/node.py +++ b/hivemind/dht/node.py @@ -350,8 +350,9 @@ async def get_many(self, keys: Collection[DHTKey], sufficient_expiration_time: O async def get_many_by_id( self, key_ids: Collection[DHTID], sufficient_expiration_time: Optional[DHTExpiration] = None, num_workers: Optional[int] = None, beam_size: Optional[int] = None, return_futures: bool = False, - _refresh_cache=True) -> Dict[DHTID, Union[Tuple[Optional[DHTValue], Optional[DHTExpiration]], - Awaitable[Tuple[Optional[DHTValue], Optional[DHTExpiration]]]]]: + _refresh_cache: bool = True, local_only: bool = False + ) -> Dict[DHTID, Union[Tuple[Optional[DHTValue], Optional[DHTExpiration]], + Awaitable[Tuple[Optional[DHTValue], Optional[DHTExpiration]]]]]: """ Traverse DHT to find a list of DHTIDs. For each key, return latest (value, expiration) or None if not found. @@ -364,6 +365,7 @@ async def get_many_by_id( :param return_futures: if True, immediately return asyncio.Future for every before interacting with the nework. The algorithm will populate these futures with (value, expiration) when it finds the corresponding key Note: canceling a future will stop search for the corresponding key + :param local_only: if True, search only this node's own storage and cache and make NO network requests :param _refresh_cache: internal flag, whether or not to self._trigger_cache_refresh :returns: for each key: value and its expiration time. 
If nothing is found, returns (None, None) for that key :note: in order to check if get returned a value, please check (expiration_time is None) @@ -416,11 +418,16 @@ async def found_callback(key_id: DHTID, nearest_nodes: List[DHTID], _visited: Se search_results[key_id].finish_search() # finish search whether or we found something self._cache_new_result(search_results[key_id], nearest_nodes, node_to_endpoint) - asyncio.create_task(traverse_dht( - queries=list(unfinished_key_ids), initial_nodes=list(node_to_endpoint), - beam_size=beam_size, num_workers=num_workers, queries_per_call=int(len(unfinished_key_ids) ** 0.5), - get_neighbors=get_neighbors, visited_nodes={key_id: {self.node_id} for key_id in unfinished_key_ids}, - found_callback=found_callback, await_all_tasks=False)) + if not local_only: + asyncio.create_task(traverse_dht( + queries=list(unfinished_key_ids), initial_nodes=list(node_to_endpoint), + beam_size=beam_size, num_workers=num_workers, queries_per_call=int(len(unfinished_key_ids) ** 0.5), + get_neighbors=get_neighbors, visited_nodes={key_id: {self.node_id} for key_id in unfinished_key_ids}, + found_callback=found_callback, await_all_tasks=False)) + else: + for key_id in unfinished_key_ids: + search_results[key_id].finish_search() + if return_futures: return {key_id: search_result.future for key_id, search_result in search_results.items()} From 5f4086783ee5514d48283ac8649ab4dddbbf3155 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Tue, 15 Sep 2020 13:06:00 +0300 Subject: [PATCH 02/16] typo --- hivemind/dht/node.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hivemind/dht/node.py b/hivemind/dht/node.py index f985f62c1..59162e0a2 100644 --- a/hivemind/dht/node.py +++ b/hivemind/dht/node.py @@ -443,7 +443,7 @@ async def found_callback(key_id: DHTID, nearest_nodes: List[DHTID], _visited: Se def _reuse_finished_search_result(self, finished: _IntermediateResult): expiration_time_threshold = max(finished.expiration_time or -float('inf'), finished.sufficient_expiration_time) concurrent_requests: SortedList[_IntermediateResult] = self.pending_get_requests[finished.key_id] - # note: concurrent_requests is sorded in the order of descending sufficient_expiration_time + # note: concurrent_requests is sorted in the order of descending sufficient_expiration_time while concurrent_requests and expiration_time_threshold >= concurrent_requests[-1].sufficient_expiration_time: concurrent_requests[-1].add_candidate(finished.binary_value, finished.expiration_time, source_node_id=finished.source_node_id) From bbc7273b14852eb72a108998e4e89b37eebc07d1 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Tue, 15 Sep 2020 13:06:36 +0300 Subject: [PATCH 03/16] beam search on dht side --- hivemind/dht/__init__.py | 106 +++++++++++++++++++++++++++++++++++++++ 1 file changed, 106 insertions(+) diff --git a/hivemind/dht/__init__.py b/hivemind/dht/__init__.py index 54af9ea44..a65c33492 100644 --- a/hivemind/dht/__init__.py +++ b/hivemind/dht/__init__.py @@ -12,6 +12,7 @@ - [1] Maymounkov P., Mazieres D. (2002) Kademlia: A Peer-to-Peer Information System Based on the XOR Metric. - [2] https://github.com/bmuller/kademlia , Brian, if you're reading this: THANK YOU! 
you're awesome :) """ +import time import asyncio import ctypes import multiprocessing as mp @@ -21,6 +22,7 @@ from typing import List, Tuple, Optional, Sequence, OrderedDict as TOrderedDict, Union, Awaitable import uvloop +import torch from hivemind.client import RemoteExpert from hivemind.dht.node import DHTNode, DHTID, DHTExpiration @@ -239,3 +241,107 @@ async def _first_k_active( # return k active prefixes or as many as we could find future.set_result(OrderedDict(found)) + + def find_best_experts(self, prefix: str, grid_scores: List[torch.Tensor], k_best: int, *, + time_budget: float = float('inf'), grid_indices: Optional[List[torch.Tensor]] = None, + return_future=False, **kwargs) -> Union[List[RemoteExpert], MPFuture[List[RemoteExpert]]]: + """ + Find and return k active experts with highest scores, use both local cache and DHT + + + :param prefix: common prefix for all expert uids in grid + :param grid_scores: scores predicted for each dimension in the grid, + :type grid_scores: a sequence of tensors of shape[batch_size, grid_size[i]] + :param grid_indices: optional, indices for each grid dimension. Default = 0, 1, ... len(grid_scores[i]) - 1 + + :param k_best: how many best experts should beam search return + :param time_budget: how much time beam_search is can spend on queries to other peers (default = unlimited) + After time_budget is reached, beam search won't search for more experts and instead fall back on local cache + Please note that any queries that fall outside the budget will still be performed in background and cached + for subsequent iterations as long as DHTNode.cache_locally is True + :param return_future: if set to True, returns MPFuture that can be awaited to get the actual result + :param kwargs: extra keyword parameters passed to self.dht.first_k_active + :returns: a list that contains *up to* k_best RemoteExpert instances + """ + if grid_indices is None: + grid_indices = [torch.arange(len(dim_scores)) for dim_scores in grid_scores] + grid_scores = [dim_scores.cpu().detach() for dim_scores in grid_scores] + grid_indices = [dim_indices.cpu() for dim_indices in grid_indices] + assert len(grid_indices) == len(grid_scores), "grid_indices (if provided) must be of same length as grid_scores" + assert all(dim_scores.ndim == 1 and dim_scores.shape == dim_indices.shape + for dim, dim_scores, dim_indices in enumerate(zip(grid_scores, grid_indices))) + + future, _future = MPFuture.make_pair() + self.pipe.send(('_find_best_experts', [], + dict(prefix=prefix, grid_scores=grid_scores, grid_indices=grid_indices, k_best=k_best, + time_budget=time_budget, future=_future, **kwargs))) + return future if return_future else future.result() + + async def _find_best_experts(self, node: DHTNode, prefix: str, grid_scores: List[torch.Tensor], + grid_indices: List[torch.Tensor], k_best: int, time_budget: float = float('inf'), + future: Optional[MPFuture] = None, **kwargs) -> List[RemoteExpert]: + deadline_time = time.perf_counter() + time_budget + beam_experts: List[RemoteExpert] = [] + beam: List[str] = [prefix] + beam_scores = torch.zeros(1) + + for dim_index, dim_scores, dim_indices in enumerate(zip(grid_scores, grid_indices)): + # create all possible successors from current beam and sort them by total score + expanded_scores = beam_scores[:, None] + dim_scores[None, :] + sorted_indices = [(flat_i // len(dim_scores), flat_i % len(dim_scores)) + for flat_i in (-expanded_scores).flatten().argsort().numpy()] + + sorted_candidates = [f"{beam[row]}{self.UID_DELIMITER}{dim_indices[col]:d}" 
for row, col in sorted_indices] + candidate_to_sorted_indices = dict(zip(sorted_candidates, sorted_indices)) + + # select k best candidates according to scores but only those that are still active + best_alive_prefixes: TOrderedDict[str, RemoteExpert] = await self._first_k_active( + node, sorted_candidates, k=k_best, time_budget=deadline_time - time.perf_counter(), **kwargs) + + if not best_alive_prefixes: + logger.warning(f"Grid is empty: found neither of {sorted_candidates}") + break + + beam = list(best_alive_prefixes.keys()) + beam_scores = expanded_scores[tuple(zip(*map(candidate_to_sorted_indices.get, beam)))] + beam_experts = list(best_alive_prefixes.values()) + + if future: + future.set_result(beam_experts) + return beam_experts + + def batch_find_best_experts(self, prefix: str, grid_scores: List[torch.Tensor], k_best: int, *, + time_budget: float = float('inf'), grid_indices: Optional[List[torch.Tensor]], + return_future=False, **kwargs) -> List[RemoteExpert]: + """ + Batch-parallel version of find_best_experts (see find_best_experts docstring for details) + The only exception is that grid_scores must now be a list of 2d tensors [batch_size, grid_size[i]] + :returns: a list of batch_size lists, each contains *up to* k_best RemoteExpert instances + """ + if grid_indices is None: + grid_indices = [torch.arange(len(dim_scores)) for dim_scores in grid_scores] + grid_scores = [dim_scores.cpu().detach() for dim_scores in grid_scores] + grid_indices = [dim_indices.cpu() for dim_indices in grid_indices] + assert len(grid_indices) == len(grid_scores), "grid_indices (if provided) must be of same length as grid_scores" + batch_size = len(grid_scores[0]) + assert all(dim_scores.ndim == 2 and dim_indices.ndim == 1 and len(dim_scores) == len(dim_indices) == batch_size + for dim, dim_scores, dim_indices in enumerate(zip(grid_scores, grid_indices))) + future, _future = MPFuture.make_pair() + self.pipe.send(('_batch_find_best_experts', [], + dict(prefix=prefix, grid_scores=grid_scores, grid_indices=grid_indices, k_best=k_best, + time_budget=time_budget, future=_future, **kwargs))) + return future if return_future else future.result() + + async def _batch_find_best_experts(self, node: DHTNode, prefix: str, grid_scores: List[torch.Tensor], + grid_indices: List[torch.Tensor], k_best: int, time_budget: float = float('inf'), + future: Optional[MPFuture] = None, **kwargs) -> List[List[RemoteExpert]]: + results = await asyncio.gather(*[ + asyncio.create_task( + self._find_best_experts(node, prefix, grid_scores=grid_scores_i, grid_indices=grid_indices, + k_best=k_best, time_budget=time_budget, **kwargs) + ) for grid_scores_i in map(list, zip(*grid_scores)) + ]) + if future: + future.set_result(results) + return results + From b7b5f87935e7229e00440a243b30eecbdc57a78d Mon Sep 17 00:00:00 2001 From: justheuristic Date: Tue, 15 Sep 2020 22:03:02 +0300 Subject: [PATCH 04/16] _IntermediateResult -> _SearchState --- hivemind/dht/node.py | 124 ++++++++++++++++++++++--------------------- 1 file changed, 64 insertions(+), 60 deletions(-) diff --git a/hivemind/dht/node.py b/hivemind/dht/node.py index 59162e0a2..21ea70d29 100644 --- a/hivemind/dht/node.py +++ b/hivemind/dht/node.py @@ -54,7 +54,7 @@ class DHTNode: node_id: DHTID; is_alive: bool; port: int; num_replicas: int; num_workers: int; protocol: DHTProtocol refresh_timeout: float; cache_locally: bool; cache_nearest: int; cache_refresh_before_expiry: float cache_refresh_available: asyncio.Event; cache_refresh_queue: LocalStorage - reuse_get_requests: bool; 
pending_get_requests: DefaultDict[DHTID, SortedList[_IntermediateResult]] + reuse_get_requests: bool; pending_get_requests: DefaultDict[DHTID, SortedList[_SearchState]] serializer = MSGPackSerializer # used to pack/unpack DHT Values for transfer over network # fmt:on @@ -336,7 +336,7 @@ async def get_many(self, keys: Collection[DHTKey], sufficient_expiration_time: O :param keys: traverse the DHT and find the value for each of these keys (or (None, None) if not key found) :param sufficient_expiration_time: if the search finds a value that expires after this time, default = time of call, find any value that did not expire by the time of call - If min_expiration_time=float('inf'), this method will find a value with _latest_ expiration + If min_expiration_time == float('inf'), this method will find a value with _latest_ expiration :param kwargs: for full list of parameters, see DHTNode.get_many_by_id :returns: for each key: value and its expiration time. If nothing is found, returns (None, None) for that key :note: in order to check if get returned a value, please check if (expiration_time is None) @@ -371,30 +371,50 @@ async def get_many_by_id( :note: in order to check if get returned a value, please check (expiration_time is None) """ sufficient_expiration_time = sufficient_expiration_time or get_dht_time() - beam_size = beam_size if beam_size is not None else self.protocol.bucket_size - num_workers = num_workers if num_workers is not None else self.num_workers - search_results: Dict[DHTID, _IntermediateResult] = {key_id: _IntermediateResult( + search_states: Dict[DHTID, _SearchState] = {key_id: _SearchState( key_id, sufficient_expiration_time, serializer=self.serializer) for key_id in key_ids} if _refresh_cache: for key_id in key_ids: - search_results[key_id].add_done_callback(self._trigger_cache_refresh) + search_states[key_id].add_done_callback(self._trigger_cache_refresh) # if we have concurrent get request for some of the same keys, subscribe to their results if self.reuse_get_requests: - for key_id, search_result in search_results.items(): - self.pending_get_requests[key_id].add(search_result) - search_result.add_done_callback(self._reuse_finished_search_result) + for key_id, search_state in search_states.items(): + self.pending_get_requests[key_id].add(search_state) + search_state.add_done_callback(self._reuse_finished_search_result) # stage 1: check for value in this node's local storage and cache for key_id in key_ids: - search_results[key_id].add_candidate(*self.protocol.storage.get(key_id), source_node_id=self.node_id) - search_results[key_id].add_candidate(*self.protocol.cache.get(key_id), source_node_id=self.node_id) + search_states[key_id].add_candidate(*self.protocol.storage.get(key_id), source_node_id=self.node_id) + search_states[key_id].add_candidate(*self.protocol.cache.get(key_id), source_node_id=self.node_id) # stage 2: traverse the DHT to get the remaining keys from remote peers - unfinished_key_ids = [key_id for key_id in key_ids if not search_results[key_id].finished] + unfinished_search_results = {key_id: search for key_id, search in search_states.items() if not search.finished} + if not local_only: + asyncio.create_task(self._get_from_other_peers(unfinished_search_results, num_workers, beam_size)) + else: # if we're not allowed to traverse DHT, finish search right now + for key_id, search_state in unfinished_search_results.items(): + search_state.finish_search() + + if return_futures: + return {key_id: search_state.future for key_id, search_state in 
search_states.items()} + else: + try: + # note: this should be first time when we await something, there's no need to "try" the entire function + return {key_id: await search_state.future for key_id, search_state in search_states.items()} + except asyncio.CancelledError as e: # terminate remaining tasks ASAP + for key_id, search_state in search_states.items(): + search_state.future.cancel() + raise e + + async def _get_from_other_peers(self, search_states: Dict[DHTID, _SearchState], num_workers: Optional[int], + beam_size: Optional[int]) -> Awaitable: + """ Internal method: call traverse_dht to get keys from across DHT, add results to search_states """ + beam_size = beam_size if beam_size is not None else self.protocol.bucket_size + num_workers = num_workers if num_workers is not None else self.num_workers node_to_endpoint: Dict[DHTID, Endpoint] = dict() # global routing table for all keys - for key_id in unfinished_key_ids: + for key_id in search_states.keys(): node_to_endpoint.update(self.protocol.routing_table.get_nearest_neighbors( key_id, self.protocol.bucket_size, exclude=self.node_id)) @@ -408,41 +428,25 @@ async def get_neighbors(peer: DHTID, queries: Collection[DHTID]) -> Dict[DHTID, output: Dict[DHTID, Tuple[Tuple[DHTID], bool]] = {} for key_id, (maybe_value_bytes, maybe_expiration_time, peers) in response.items(): node_to_endpoint.update(peers) - search_results[key_id].add_candidate(maybe_value_bytes, maybe_expiration_time, source_node_id=peer) - output[key_id] = tuple(peers.keys()), search_results[key_id].finished + search_states[key_id].add_candidate(maybe_value_bytes, maybe_expiration_time, source_node_id=peer) + output[key_id] = tuple(peers.keys()), search_states[key_id].finished # note: we interrupt search either if key is either found or finished otherwise (e.g. 
cancelled by user) return output # V-- this function will be called exactly once when traverse_dht finishes search for a given key async def found_callback(key_id: DHTID, nearest_nodes: List[DHTID], _visited: Set[DHTID]): - search_results[key_id].finish_search() # finish search whether or we found something - self._cache_new_result(search_results[key_id], nearest_nodes, node_to_endpoint) - - if not local_only: - asyncio.create_task(traverse_dht( - queries=list(unfinished_key_ids), initial_nodes=list(node_to_endpoint), - beam_size=beam_size, num_workers=num_workers, queries_per_call=int(len(unfinished_key_ids) ** 0.5), - get_neighbors=get_neighbors, visited_nodes={key_id: {self.node_id} for key_id in unfinished_key_ids}, - found_callback=found_callback, await_all_tasks=False)) - else: - for key_id in unfinished_key_ids: - search_results[key_id].finish_search() - + search_states[key_id].finish_search() # finish search whether or we found something + self._cache_new_result(search_states[key_id], nearest_nodes, node_to_endpoint) - if return_futures: - return {key_id: search_result.future for key_id, search_result in search_results.items()} - else: - try: - # note: this should be first time when we await something, there's no need to "try" the entire function - return {key_id: await search_result.future for key_id, search_result in search_results.items()} - except asyncio.CancelledError as e: # terminate remaining tasks ASAP - for key_id, search_result in search_results.items(): - search_result.future.cancel() - raise e + return traverse_dht( + queries=list(search_states.keys()), initial_nodes=list(node_to_endpoint), + beam_size=beam_size, num_workers=num_workers, queries_per_call=int(len(search_states) ** 0.5), + get_neighbors=get_neighbors, visited_nodes={key_id: {self.node_id} for key_id in search_states.keys()}, + found_callback=found_callback, await_all_tasks=False) - def _reuse_finished_search_result(self, finished: _IntermediateResult): + def _reuse_finished_search_result(self, finished: _SearchState): expiration_time_threshold = max(finished.expiration_time or -float('inf'), finished.sufficient_expiration_time) - concurrent_requests: SortedList[_IntermediateResult] = self.pending_get_requests[finished.key_id] + concurrent_requests: SortedList[_SearchState] = self.pending_get_requests[finished.key_id] # note: concurrent_requests is sorted in the order of descending sufficient_expiration_time while concurrent_requests and expiration_time_threshold >= concurrent_requests[-1].sufficient_expiration_time: concurrent_requests[-1].add_candidate(finished.binary_value, finished.expiration_time, @@ -450,14 +454,14 @@ def _reuse_finished_search_result(self, finished: _IntermediateResult): concurrent_requests[-1].finish_search() concurrent_requests.pop(-1) - def _trigger_cache_refresh(self, result: _IntermediateResult): + def _trigger_cache_refresh(self, search: _SearchState): """ Called after get request is finished (whether it was found, not found, hit cache, cancelled, or reused) """ - if result.found_something and result.source_node_id == self.node_id: + if search.found_something and search.source_node_id == self.node_id: with self.protocol.cache.freeze(): # do not clear outdated cache for now... 
- if self.cache_refresh_before_expiry and result.key_id in self.protocol.cache: + if self.cache_refresh_before_expiry and search.key_id in self.protocol.cache: previous_earliest_item: Tuple[DHTID, BinaryDHTValue, DHTExpiration] = self.cache_refresh_queue.top() - self.cache_refresh_queue.store(result.key_id, result.binary_value, result.expiration_time) - if previous_earliest_item is None or result.expiration_time < previous_earliest_item[-1]: + self.cache_refresh_queue.store(search.key_id, search.binary_value, search.expiration_time) + if previous_earliest_item is None or search.expiration_time < previous_earliest_item[-1]: self.cache_refresh_available.set() # if we new element is now earliest, notify the cache queue async def _refresh_stale_cache_entries(self): @@ -494,22 +498,22 @@ async def _refresh_stale_cache_entries(self): keys_to_refresh, sufficient_expiration_time=nearest_expiration + self.cache_refresh_before_expiry, _refresh_cache=False) # if we found value locally, we shouldn't trigger another refresh - def _cache_new_result(self, result: _IntermediateResult, nearest_nodes: List[DHTID], + def _cache_new_result(self, search: _SearchState, nearest_nodes: List[DHTID], node_to_endpoint: Dict[DHTID, Endpoint]): """ after key_id is found, update cache according to caching policy. used internally in get and get_many """ - if result.found_something: - previous_expiration_time = max(self.protocol.storage.get(result.key_id)[1] or -float('inf'), - self.protocol.cache.get(result.key_id)[1] or -float('inf')) - if result.expiration_time > previous_expiration_time: # if this value has better expiration + if search.found_something: + previous_expiration_time = max(self.protocol.storage.get(search.key_id)[1] or -float('inf'), + self.protocol.cache.get(search.key_id)[1] or -float('inf')) + if search.expiration_time > previous_expiration_time: # if this value has better expiration if self.cache_locally: - self.protocol.cache.store(result.key_id, result.binary_value, result.expiration_time) + self.protocol.cache.store(search.key_id, search.binary_value, search.expiration_time) if self.cache_nearest: num_cached_nodes = 0 for node_id in nearest_nodes: - if node_id == result.source_node_id: + if node_id == search.source_node_id: continue asyncio.create_task(self.protocol.call_store( - node_to_endpoint[node_id], [result.key_id], [result.binary_value], [result.expiration_time], + node_to_endpoint[node_id], [search.key_id], [search.binary_value], [search.expiration_time], in_cache=True)) num_cached_nodes += 1 if num_cached_nodes >= self.cache_nearest: @@ -530,8 +534,8 @@ async def _refresh_routing_table(self, *, period: Optional[float]) -> None: @dataclass(init=True, repr=True, frozen=False, order=False) -class _IntermediateResult: - """ A helper class that stores current-best GET results with metadata """ +class _SearchState: + """ A helper class that stores current-best GET results and all search metadata """ key_id: DHTID sufficient_expiration_time: DHTExpiration binary_value: Optional[BinaryDHTValue] = None @@ -547,13 +551,13 @@ def add_candidate(self, binary_value: Optional[BinaryDHTValue], expiration_time: if self.expiration_time >= self.sufficient_expiration_time: self.finish_search() - def add_done_callback(self, callback: Callable[[_IntermediateResult], Any]): - """ Add callback that will be called when _IntermediateSearchResult is done (found OR cancelled by user) """ + def add_done_callback(self, callback: Callable[[_SearchState], Any]): + """ Add callback that will be called when _SearchState 
is done (found OR cancelled by user) """ self.future.add_done_callback(lambda _future: callback(self)) def finish_search(self): if self.future.done(): - return # either user cancelled our result or someone sent it before us. Nothing more to do here. + return # either user cancelled our search or someone sent it before us. Nothing more to do here. deserialized_value = self.serializer.loads(self.binary_value) if self.found_something else None self.future.set_result((deserialized_value, self.expiration_time)) @@ -566,6 +570,6 @@ def found_something(self) -> bool: def finished(self) -> bool: return self.future.done() - def __lt__(self, other: _IntermediateResult): - """ _IntermediateResult instances will be sorted by their target expiration time """ + def __lt__(self, other: _SearchState): + """ _SearchState instances will be sorted by their target expiration time """ return self.sufficient_expiration_time < other.sufficient_expiration_time From d3b00d45d28bb085c8e392dfdf00504aa188cb1c Mon Sep 17 00:00:00 2001 From: justheuristic Date: Tue, 15 Sep 2020 22:17:22 +0300 Subject: [PATCH 05/16] rollback local_only, rollback _SearchState --- hivemind/dht/node.py | 127 ++++++++++++++++++++----------------------- 1 file changed, 58 insertions(+), 69 deletions(-) diff --git a/hivemind/dht/node.py b/hivemind/dht/node.py index 21ea70d29..47218de00 100644 --- a/hivemind/dht/node.py +++ b/hivemind/dht/node.py @@ -54,7 +54,7 @@ class DHTNode: node_id: DHTID; is_alive: bool; port: int; num_replicas: int; num_workers: int; protocol: DHTProtocol refresh_timeout: float; cache_locally: bool; cache_nearest: int; cache_refresh_before_expiry: float cache_refresh_available: asyncio.Event; cache_refresh_queue: LocalStorage - reuse_get_requests: bool; pending_get_requests: DefaultDict[DHTID, SortedList[_SearchState]] + reuse_get_requests: bool; pending_get_requests: DefaultDict[DHTID, SortedList[_IntermediateResult]] serializer = MSGPackSerializer # used to pack/unpack DHT Values for transfer over network # fmt:on @@ -336,7 +336,7 @@ async def get_many(self, keys: Collection[DHTKey], sufficient_expiration_time: O :param keys: traverse the DHT and find the value for each of these keys (or (None, None) if not key found) :param sufficient_expiration_time: if the search finds a value that expires after this time, default = time of call, find any value that did not expire by the time of call - If min_expiration_time == float('inf'), this method will find a value with _latest_ expiration + If min_expiration_time=float('inf'), this method will find a value with _latest_ expiration :param kwargs: for full list of parameters, see DHTNode.get_many_by_id :returns: for each key: value and its expiration time. 
If nothing is found, returns (None, None) for that key :note: in order to check if get returned a value, please check if (expiration_time is None) @@ -350,9 +350,8 @@ async def get_many(self, keys: Collection[DHTKey], sufficient_expiration_time: O async def get_many_by_id( self, key_ids: Collection[DHTID], sufficient_expiration_time: Optional[DHTExpiration] = None, num_workers: Optional[int] = None, beam_size: Optional[int] = None, return_futures: bool = False, - _refresh_cache: bool = True, local_only: bool = False - ) -> Dict[DHTID, Union[Tuple[Optional[DHTValue], Optional[DHTExpiration]], - Awaitable[Tuple[Optional[DHTValue], Optional[DHTExpiration]]]]]: + _refresh_cache=True) -> Dict[DHTID, Union[Tuple[Optional[DHTValue], Optional[DHTExpiration]], + Awaitable[Tuple[Optional[DHTValue], Optional[DHTExpiration]]]]]: """ Traverse DHT to find a list of DHTIDs. For each key, return latest (value, expiration) or None if not found. @@ -365,56 +364,35 @@ async def get_many_by_id( :param return_futures: if True, immediately return asyncio.Future for every before interacting with the nework. The algorithm will populate these futures with (value, expiration) when it finds the corresponding key Note: canceling a future will stop search for the corresponding key - :param local_only: if True, search only this node's own storage and cache and make NO network requests :param _refresh_cache: internal flag, whether or not to self._trigger_cache_refresh :returns: for each key: value and its expiration time. If nothing is found, returns (None, None) for that key :note: in order to check if get returned a value, please check (expiration_time is None) """ sufficient_expiration_time = sufficient_expiration_time or get_dht_time() - search_states: Dict[DHTID, _SearchState] = {key_id: _SearchState( + beam_size = beam_size if beam_size is not None else self.protocol.bucket_size + num_workers = num_workers if num_workers is not None else self.num_workers + search_results: Dict[DHTID, _IntermediateResult] = {key_id: _IntermediateResult( key_id, sufficient_expiration_time, serializer=self.serializer) for key_id in key_ids} if _refresh_cache: for key_id in key_ids: - search_states[key_id].add_done_callback(self._trigger_cache_refresh) + search_results[key_id].add_done_callback(self._trigger_cache_refresh) # if we have concurrent get request for some of the same keys, subscribe to their results if self.reuse_get_requests: - for key_id, search_state in search_states.items(): - self.pending_get_requests[key_id].add(search_state) - search_state.add_done_callback(self._reuse_finished_search_result) + for key_id, search_result in search_results.items(): + self.pending_get_requests[key_id].add(search_result) + search_result.add_done_callback(self._reuse_finished_search_result) # stage 1: check for value in this node's local storage and cache for key_id in key_ids: - search_states[key_id].add_candidate(*self.protocol.storage.get(key_id), source_node_id=self.node_id) - search_states[key_id].add_candidate(*self.protocol.cache.get(key_id), source_node_id=self.node_id) + search_results[key_id].add_candidate(*self.protocol.storage.get(key_id), source_node_id=self.node_id) + search_results[key_id].add_candidate(*self.protocol.cache.get(key_id), source_node_id=self.node_id) # stage 2: traverse the DHT to get the remaining keys from remote peers - unfinished_search_results = {key_id: search for key_id, search in search_states.items() if not search.finished} - if not local_only: - 
asyncio.create_task(self._get_from_other_peers(unfinished_search_results, num_workers, beam_size)) - else: # if we're not allowed to traverse DHT, finish search right now - for key_id, search_state in unfinished_search_results.items(): - search_state.finish_search() - - if return_futures: - return {key_id: search_state.future for key_id, search_state in search_states.items()} - else: - try: - # note: this should be first time when we await something, there's no need to "try" the entire function - return {key_id: await search_state.future for key_id, search_state in search_states.items()} - except asyncio.CancelledError as e: # terminate remaining tasks ASAP - for key_id, search_state in search_states.items(): - search_state.future.cancel() - raise e - - async def _get_from_other_peers(self, search_states: Dict[DHTID, _SearchState], num_workers: Optional[int], - beam_size: Optional[int]) -> Awaitable: - """ Internal method: call traverse_dht to get keys from across DHT, add results to search_states """ - beam_size = beam_size if beam_size is not None else self.protocol.bucket_size - num_workers = num_workers if num_workers is not None else self.num_workers + unfinished_key_ids = [key_id for key_id in key_ids if not search_results[key_id].finished] node_to_endpoint: Dict[DHTID, Endpoint] = dict() # global routing table for all keys - for key_id in search_states.keys(): + for key_id in unfinished_key_ids: node_to_endpoint.update(self.protocol.routing_table.get_nearest_neighbors( key_id, self.protocol.bucket_size, exclude=self.node_id)) @@ -428,40 +406,51 @@ async def get_neighbors(peer: DHTID, queries: Collection[DHTID]) -> Dict[DHTID, output: Dict[DHTID, Tuple[Tuple[DHTID], bool]] = {} for key_id, (maybe_value_bytes, maybe_expiration_time, peers) in response.items(): node_to_endpoint.update(peers) - search_states[key_id].add_candidate(maybe_value_bytes, maybe_expiration_time, source_node_id=peer) - output[key_id] = tuple(peers.keys()), search_states[key_id].finished + search_results[key_id].add_candidate(maybe_value_bytes, maybe_expiration_time, source_node_id=peer) + output[key_id] = tuple(peers.keys()), search_results[key_id].finished # note: we interrupt search either if key is either found or finished otherwise (e.g. 
cancelled by user) return output # V-- this function will be called exactly once when traverse_dht finishes search for a given key async def found_callback(key_id: DHTID, nearest_nodes: List[DHTID], _visited: Set[DHTID]): - search_states[key_id].finish_search() # finish search whether or we found something - self._cache_new_result(search_states[key_id], nearest_nodes, node_to_endpoint) + search_results[key_id].finish_search() # finish search whether or we found something + self._cache_new_result(search_results[key_id], nearest_nodes, node_to_endpoint) - return traverse_dht( - queries=list(search_states.keys()), initial_nodes=list(node_to_endpoint), - beam_size=beam_size, num_workers=num_workers, queries_per_call=int(len(search_states) ** 0.5), - get_neighbors=get_neighbors, visited_nodes={key_id: {self.node_id} for key_id in search_states.keys()}, - found_callback=found_callback, await_all_tasks=False) + asyncio.create_task(traverse_dht( + queries=list(unfinished_key_ids), initial_nodes=list(node_to_endpoint), + beam_size=beam_size, num_workers=num_workers, queries_per_call=int(len(unfinished_key_ids) ** 0.5), + get_neighbors=get_neighbors, visited_nodes={key_id: {self.node_id} for key_id in unfinished_key_ids}, + found_callback=found_callback, await_all_tasks=False)) + + if return_futures: + return {key_id: search_result.future for key_id, search_result in search_results.items()} + else: + try: + # note: this should be first time when we await something, there's no need to "try" the entire function + return {key_id: await search_result.future for key_id, search_result in search_results.items()} + except asyncio.CancelledError as e: # terminate remaining tasks ASAP + for key_id, search_result in search_results.items(): + search_result.future.cancel() + raise e - def _reuse_finished_search_result(self, finished: _SearchState): + def _reuse_finished_search_result(self, finished: _IntermediateResult): expiration_time_threshold = max(finished.expiration_time or -float('inf'), finished.sufficient_expiration_time) - concurrent_requests: SortedList[_SearchState] = self.pending_get_requests[finished.key_id] - # note: concurrent_requests is sorted in the order of descending sufficient_expiration_time + concurrent_requests: SortedList[_IntermediateResult] = self.pending_get_requests[finished.key_id] + # note: concurrent_requests is sorded in the order of descending sufficient_expiration_time while concurrent_requests and expiration_time_threshold >= concurrent_requests[-1].sufficient_expiration_time: concurrent_requests[-1].add_candidate(finished.binary_value, finished.expiration_time, source_node_id=finished.source_node_id) concurrent_requests[-1].finish_search() concurrent_requests.pop(-1) - def _trigger_cache_refresh(self, search: _SearchState): + def _trigger_cache_refresh(self, result: _IntermediateResult): """ Called after get request is finished (whether it was found, not found, hit cache, cancelled, or reused) """ - if search.found_something and search.source_node_id == self.node_id: + if result.found_something and result.source_node_id == self.node_id: with self.protocol.cache.freeze(): # do not clear outdated cache for now... 
- if self.cache_refresh_before_expiry and search.key_id in self.protocol.cache: + if self.cache_refresh_before_expiry and result.key_id in self.protocol.cache: previous_earliest_item: Tuple[DHTID, BinaryDHTValue, DHTExpiration] = self.cache_refresh_queue.top() - self.cache_refresh_queue.store(search.key_id, search.binary_value, search.expiration_time) - if previous_earliest_item is None or search.expiration_time < previous_earliest_item[-1]: + self.cache_refresh_queue.store(result.key_id, result.binary_value, result.expiration_time) + if previous_earliest_item is None or result.expiration_time < previous_earliest_item[-1]: self.cache_refresh_available.set() # if we new element is now earliest, notify the cache queue async def _refresh_stale_cache_entries(self): @@ -498,22 +487,22 @@ async def _refresh_stale_cache_entries(self): keys_to_refresh, sufficient_expiration_time=nearest_expiration + self.cache_refresh_before_expiry, _refresh_cache=False) # if we found value locally, we shouldn't trigger another refresh - def _cache_new_result(self, search: _SearchState, nearest_nodes: List[DHTID], + def _cache_new_result(self, result: _IntermediateResult, nearest_nodes: List[DHTID], node_to_endpoint: Dict[DHTID, Endpoint]): """ after key_id is found, update cache according to caching policy. used internally in get and get_many """ - if search.found_something: - previous_expiration_time = max(self.protocol.storage.get(search.key_id)[1] or -float('inf'), - self.protocol.cache.get(search.key_id)[1] or -float('inf')) - if search.expiration_time > previous_expiration_time: # if this value has better expiration + if result.found_something: + previous_expiration_time = max(self.protocol.storage.get(result.key_id)[1] or -float('inf'), + self.protocol.cache.get(result.key_id)[1] or -float('inf')) + if result.expiration_time > previous_expiration_time: # if this value has better expiration if self.cache_locally: - self.protocol.cache.store(search.key_id, search.binary_value, search.expiration_time) + self.protocol.cache.store(result.key_id, result.binary_value, result.expiration_time) if self.cache_nearest: num_cached_nodes = 0 for node_id in nearest_nodes: - if node_id == search.source_node_id: + if node_id == result.source_node_id: continue asyncio.create_task(self.protocol.call_store( - node_to_endpoint[node_id], [search.key_id], [search.binary_value], [search.expiration_time], + node_to_endpoint[node_id], [result.key_id], [result.binary_value], [result.expiration_time], in_cache=True)) num_cached_nodes += 1 if num_cached_nodes >= self.cache_nearest: @@ -534,8 +523,8 @@ async def _refresh_routing_table(self, *, period: Optional[float]) -> None: @dataclass(init=True, repr=True, frozen=False, order=False) -class _SearchState: - """ A helper class that stores current-best GET results and all search metadata """ +class _IntermediateResult: + """ A helper class that stores current-best GET results with metadata """ key_id: DHTID sufficient_expiration_time: DHTExpiration binary_value: Optional[BinaryDHTValue] = None @@ -551,13 +540,13 @@ def add_candidate(self, binary_value: Optional[BinaryDHTValue], expiration_time: if self.expiration_time >= self.sufficient_expiration_time: self.finish_search() - def add_done_callback(self, callback: Callable[[_SearchState], Any]): - """ Add callback that will be called when _SearchState is done (found OR cancelled by user) """ + def add_done_callback(self, callback: Callable[[_IntermediateResult], Any]): + """ Add callback that will be called when _IntermediateSearchResult 
is done (found OR cancelled by user) """ self.future.add_done_callback(lambda _future: callback(self)) def finish_search(self): if self.future.done(): - return # either user cancelled our search or someone sent it before us. Nothing more to do here. + return # either user cancelled our result or someone sent it before us. Nothing more to do here. deserialized_value = self.serializer.loads(self.binary_value) if self.found_something else None self.future.set_result((deserialized_value, self.expiration_time)) @@ -570,6 +559,6 @@ def found_something(self) -> bool: def finished(self) -> bool: return self.future.done() - def __lt__(self, other: _SearchState): - """ _SearchState instances will be sorted by their target expiration time """ + def __lt__(self, other: _IntermediateResult): + """ _IntermediateResult instances will be sorted by their target expiration time """ return self.sufficient_expiration_time < other.sufficient_expiration_time From 9c539d32be5ef1442f441d1adaf528d7d84fca81 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Tue, 15 Sep 2020 22:34:43 +0300 Subject: [PATCH 06/16] remove return_futures from DHTNode --- hivemind/dht/node.py | 40 ++++++++++++++++++++++------------------ 1 file changed, 22 insertions(+), 18 deletions(-) diff --git a/hivemind/dht/node.py b/hivemind/dht/node.py index 47218de00..aab6b4f1b 100644 --- a/hivemind/dht/node.py +++ b/hivemind/dht/node.py @@ -327,9 +327,17 @@ async def get(self, key: DHTKey, latest=False, **kwargs) -> Tuple[Optional[DHTVa result = await self.get_many([key]) return result[key] + def get_local(self, key: DHTKey) -> Tuple[Optional[DHTValue], Optional[DHTExpiration]]: + """ Like DHTNode.get, but only search for key in node's local storage and cache """ + key_id = DHTID.generate(source=key) + maybe_value_bytes, maybe_expiration = self.protocol.storage.get(key_id) + maybe_cached_value, maybe_cache_expiration = self.protocol.cache.get(key_id) + if (maybe_cache_expiration or -float('inf')) > (maybe_expiration or -float('inf')): + maybe_value_bytes, maybe_expiration = maybe_cached_value, maybe_cache_expiration + return maybe_value_bytes, maybe_expiration + async def get_many(self, keys: Collection[DHTKey], sufficient_expiration_time: Optional[DHTExpiration] = None, - **kwargs) -> Dict[DHTKey, Union[Tuple[Optional[DHTValue], Optional[DHTExpiration]], - Awaitable[Tuple[Optional[DHTValue], Optional[DHTExpiration]]]]]: + **kwargs) -> Dict[DHTKey, Tuple[Optional[DHTValue], Optional[DHTExpiration]]]: """ Traverse DHT to find a list of keys. For each key, return latest (value, expiration) or None if not found. @@ -349,9 +357,8 @@ async def get_many(self, keys: Collection[DHTKey], sufficient_expiration_time: O async def get_many_by_id( self, key_ids: Collection[DHTID], sufficient_expiration_time: Optional[DHTExpiration] = None, - num_workers: Optional[int] = None, beam_size: Optional[int] = None, return_futures: bool = False, - _refresh_cache=True) -> Dict[DHTID, Union[Tuple[Optional[DHTValue], Optional[DHTExpiration]], - Awaitable[Tuple[Optional[DHTValue], Optional[DHTExpiration]]]]]: + num_workers: Optional[int] = None, beam_size: Optional[int] = None, _refresh_cache=True + ) -> Dict[DHTID, Tuple[Optional[DHTValue], Optional[DHTExpiration]]]: """ Traverse DHT to find a list of DHTIDs. For each key, return latest (value, expiration) or None if not found. 
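# --- editorial aside, not part of the patch series: a minimal sketch of how the
# (value, expiration_time) convention documented in DHTNode.get / get_many is consumed
# after this commit removes return_futures. `node` is assumed to be an already-running
# DHTNode instance and the key names are purely illustrative.
async def print_live_values(node, keys):
    results = await node.get_many(keys)  # {key: (value, expiration_time)}
    for key, (value, expiration_time) in results.items():
        if expiration_time is None:  # (None, None) means no un-expired value was found
            print(f"{key}: not found")
        else:
            print(f"{key} -> {value!r} (expires at {expiration_time})")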
@@ -361,9 +368,6 @@ async def get_many_by_id( If min_expiration_time=float('inf'), this method will find a value with _latest_ expiration :param beam_size: maintains up to this many nearest nodes when crawling dht, default beam_size = bucket_size :param num_workers: override for default num_workers, see traverse_dht num_workers param - :param return_futures: if True, immediately return asyncio.Future for every before interacting with the nework. - The algorithm will populate these futures with (value, expiration) when it finds the corresponding key - Note: canceling a future will stop search for the corresponding key :param _refresh_cache: internal flag, whether or not to self._trigger_cache_refresh :returns: for each key: value and its expiration time. If nothing is found, returns (None, None) for that key :note: in order to check if get returned a value, please check (expiration_time is None) @@ -422,16 +426,13 @@ async def found_callback(key_id: DHTID, nearest_nodes: List[DHTID], _visited: Se get_neighbors=get_neighbors, visited_nodes={key_id: {self.node_id} for key_id in unfinished_key_ids}, found_callback=found_callback, await_all_tasks=False)) - if return_futures: - return {key_id: search_result.future for key_id, search_result in search_results.items()} - else: - try: - # note: this should be first time when we await something, there's no need to "try" the entire function - return {key_id: await search_result.future for key_id, search_result in search_results.items()} - except asyncio.CancelledError as e: # terminate remaining tasks ASAP - for key_id, search_result in search_results.items(): - search_result.future.cancel() - raise e + try: + # note: this should be first time when we await something, there's no need to "try" the entire function + return {key_id: await search_result for key_id, search_result in search_results.items()} + except asyncio.CancelledError as e: # terminate remaining tasks ASAP + for key_id, search_result in search_results.items(): + search_result.finish_search() + raise e def _reuse_finished_search_result(self, finished: _IntermediateResult): expiration_time_threshold = max(finished.expiration_time or -float('inf'), finished.sufficient_expiration_time) @@ -562,3 +563,6 @@ def finished(self) -> bool: def __lt__(self, other: _IntermediateResult): """ _IntermediateResult instances will be sorted by their target expiration time """ return self.sufficient_expiration_time < other.sufficient_expiration_time + + def __await__(self): + return self.future.__await__() \ No newline at end of file From 15b1d4154ba5ceab79cae724abb57567aee36717 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Wed, 16 Sep 2020 14:09:47 +0300 Subject: [PATCH 07/16] better end condition in first_k_active --- hivemind/dht/__init__.py | 208 ++++++++++++++++----------------------- 1 file changed, 86 insertions(+), 122 deletions(-) diff --git a/hivemind/dht/__init__.py b/hivemind/dht/__init__.py index a65c33492..7a5207ef9 100644 --- a/hivemind/dht/__init__.py +++ b/hivemind/dht/__init__.py @@ -19,10 +19,10 @@ import warnings from collections import deque, OrderedDict from concurrent.futures import ThreadPoolExecutor -from typing import List, Tuple, Optional, Sequence, OrderedDict as TOrderedDict, Union, Awaitable +from typing import List, Tuple, Optional, Sequence, OrderedDict as TOrderedDict, Union, Awaitable, Dict, Deque import uvloop -import torch +import numpy as np from hivemind.client import RemoteExpert from hivemind.dht.node import DHTNode, DHTID, DHTExpiration @@ -68,6 +68,7 @@ class 
DHT(mp.Process): indices from the previous step. Finally, MoE will use DHT.get_experts(uids: List[str]) search for specific experts. This beam search explores one additional dimension per step and finds k best experts from across the DHT in O(k / s * log(N)) average time where s is grid sparsity rate and N is the total number of experts. + TODO(jheuristic) replace _first_k_active with beam search description! """ UID_DELIMITER = '.' # when declaring experts, DHT store all prefixes of that expert's uid, split over this prefix @@ -184,76 +185,15 @@ async def _declare_experts(self, node: DHTNode, uids: List[str], endpoint: Endpo if future is not None: future.set_result([store_ok[key] for key in data_to_store.keys()]) - def first_k_active( - self, uid_prefixes: List[str], k: int, max_prefetch: int = 1, chunk_size: Optional[int] = None, - return_future=False) -> Union[TOrderedDict[str, RemoteExpert], Awaitable[TOrderedDict[str, RemoteExpert]]]: + def find_best_experts( + self, prefix: str, grid_scores: List[np.ndarray], k_best: int, *, time_budget: float = float('inf'), + return_future=False, **kwargs) -> Union[List[RemoteExpert], MPFuture]: """ - Find k prefixes with active experts; may return less if there aren't enough; used for DMoE beam search - - :param uid_prefixes: a list of uid prefixes ordered from highest to lowest priority - :param k: return at most *this many* active prefixes - :param max_prefetch: pre-dispatch up to *this many* tasks (each for chunk_size experts) - :param chunk_size: dispatch this many requests in one task - :param return_future: if False (default), return when experts are returned. Otherwise return MPFuture. - :returns: a ordered dict{uid_prefix -> RemoteExpert} mapping at most :k: prefixes to matching experts - The keys in the returned dict are ordered same as in uid_prefixes. 
- """ - assert not isinstance(uid_prefixes, str), "please provide a list/tuple of prefixes as the first argument" - future, _future = MPFuture.make_pair() - self.pipe.send(('_first_k_active', [], - dict(uid_prefixes=uid_prefixes, k=k, max_prefetch=max_prefetch, - chunk_size=chunk_size or k, future=_future))) - return future if return_future else future.result() - - async def _first_k_active( - self, node: DHTNode, uid_prefixes: List[str], k: int, max_prefetch: int, chunk_size: int, future: MPFuture): - num_workers_per_chunk = min(chunk_size, self.max_workers or chunk_size) - total_chunks = (len(uid_prefixes) - 1) // chunk_size + 1 - found: List[Tuple[str, RemoteExpert]] = [] - - pending_tasks = deque( - asyncio.create_task(node.get_many(uid_prefixes[chunk_i * chunk_size: (chunk_i + 1) * chunk_size], - num_workers=num_workers_per_chunk)) - for chunk_i in range(min(max_prefetch + 1, total_chunks)) - ) # pre-dispatch first task and up to max_prefetch additional tasks - - for chunk_i in range(total_chunks): - # parse task results in chronological order, launch additional tasks on demand - response = await pending_tasks.popleft() - for uid_prefix in uid_prefixes[chunk_i * chunk_size: (chunk_i + 1) * chunk_size]: - maybe_expert_data, maybe_expiration_time = response[uid_prefix] - if maybe_expiration_time is not None: # found active peer - found.append((uid_prefix, RemoteExpert(**maybe_expert_data))) - # if we found enough active experts, finish immediately - if len(found) >= k: - break - if len(found) >= k: - break - - pre_dispatch_chunk_i = chunk_i + len(pending_tasks) + 1 - if pre_dispatch_chunk_i < total_chunks: - pending_tasks.append(asyncio.create_task(node.get_many( - uid_prefixes[pre_dispatch_chunk_i * chunk_size: (pre_dispatch_chunk_i + 1) * chunk_size], - num_workers=num_workers_per_chunk))) - - for task in pending_tasks: - task.cancel() - - # return k active prefixes or as many as we could find - future.set_result(OrderedDict(found)) - - def find_best_experts(self, prefix: str, grid_scores: List[torch.Tensor], k_best: int, *, - time_budget: float = float('inf'), grid_indices: Optional[List[torch.Tensor]] = None, - return_future=False, **kwargs) -> Union[List[RemoteExpert], MPFuture[List[RemoteExpert]]]: - """ - Find and return k active experts with highest scores, use both local cache and DHT - + Find and return k_best active experts with highest scores, use both local cache and DHT :param prefix: common prefix for all expert uids in grid :param grid_scores: scores predicted for each dimension in the grid, - :type grid_scores: a sequence of tensors of shape[batch_size, grid_size[i]] - :param grid_indices: optional, indices for each grid dimension. Default = 0, 1, ... 
len(grid_scores[i]) - 1 - + :type grid_scores: model scores for each grid dimension, list of arrays of shape grid_size[i] :param k_best: how many best experts should beam search return :param time_budget: how much time beam_search is can spend on queries to other peers (default = unlimited) After time_budget is reached, beam search won't search for more experts and instead fall back on local cache @@ -263,85 +203,109 @@ def find_best_experts(self, prefix: str, grid_scores: List[torch.Tensor], k_best :param kwargs: extra keyword parameters passed to self.dht.first_k_active :returns: a list that contains *up to* k_best RemoteExpert instances """ - if grid_indices is None: - grid_indices = [torch.arange(len(dim_scores)) for dim_scores in grid_scores] - grid_scores = [dim_scores.cpu().detach() for dim_scores in grid_scores] - grid_indices = [dim_indices.cpu() for dim_indices in grid_indices] - assert len(grid_indices) == len(grid_scores), "grid_indices (if provided) must be of same length as grid_scores" - assert all(dim_scores.ndim == 1 and dim_scores.shape == dim_indices.shape - for dim, dim_scores, dim_indices in enumerate(zip(grid_scores, grid_indices))) + grid_scores = list(map(np.asanyrray, grid_scores)) + assert all(dim_scores.ndim == 1 for dim_scores in grid_scores) future, _future = MPFuture.make_pair() - self.pipe.send(('_find_best_experts', [], - dict(prefix=prefix, grid_scores=grid_scores, grid_indices=grid_indices, k_best=k_best, - time_budget=time_budget, future=_future, **kwargs))) + self.pipe.send(('_find_best_experts', [], dict(prefix=prefix, grid_scores=grid_scores, k_best=k_best, + time_budget=time_budget, future=_future, **kwargs))) return future if return_future else future.result() - async def _find_best_experts(self, node: DHTNode, prefix: str, grid_scores: List[torch.Tensor], - grid_indices: List[torch.Tensor], k_best: int, time_budget: float = float('inf'), - future: Optional[MPFuture] = None, **kwargs) -> List[RemoteExpert]: + async def _find_best_experts(self, node: DHTNode, prefix: str, grid_scores: List[np.ndarray], k_best: int, + time_budget: float, future: MPFuture, **kwargs) -> List[RemoteExpert]: deadline_time = time.perf_counter() + time_budget beam_experts: List[RemoteExpert] = [] beam: List[str] = [prefix] - beam_scores = torch.zeros(1) + beam_scores = np.zeros(1) - for dim_index, dim_scores, dim_indices in enumerate(zip(grid_scores, grid_indices)): + for dim_index, dim_scores in enumerate(grid_scores): # create all possible successors from current beam and sort them by total score expanded_scores = beam_scores[:, None] + dim_scores[None, :] - sorted_indices = [(flat_i // len(dim_scores), flat_i % len(dim_scores)) - for flat_i in (-expanded_scores).flatten().argsort().numpy()] + sorted_indices_ravel = (-expanded_scores).ravel().argsort() + sorted_indices = np.stack(np.unravel_index(sorted_indices_ravel, expanded_scores.shape), axis=-1) - sorted_candidates = [f"{beam[row]}{self.UID_DELIMITER}{dim_indices[col]:d}" for row, col in sorted_indices] - candidate_to_sorted_indices = dict(zip(sorted_candidates, sorted_indices)) + sorted_prefix_uids = [f"{beam[row]}{self.UID_DELIMITER}{col:d}" for row, col in sorted_indices] + candidate_to_sorted_indices: Dict[str, Sequence[int]] = dict(zip(sorted_prefix_uids, sorted_indices)) # select k best candidates according to scores but only those that are still active best_alive_prefixes: TOrderedDict[str, RemoteExpert] = await self._first_k_active( - node, sorted_candidates, k=k_best, time_budget=deadline_time - 
time.perf_counter(), **kwargs) + node, sorted_prefix_uids, k=k_best, time_budget=deadline_time - time.perf_counter(), **kwargs) if not best_alive_prefixes: - logger.warning(f"Grid is empty: found neither of {sorted_candidates}") + logger.warning(f"Grid is empty: found neither of {sorted_prefix_uids}") break beam = list(best_alive_prefixes.keys()) beam_scores = expanded_scores[tuple(zip(*map(candidate_to_sorted_indices.get, beam)))] beam_experts = list(best_alive_prefixes.values()) - if future: - future.set_result(beam_experts) + future.set_result(beam_experts) return beam_experts - def batch_find_best_experts(self, prefix: str, grid_scores: List[torch.Tensor], k_best: int, *, - time_budget: float = float('inf'), grid_indices: Optional[List[torch.Tensor]], - return_future=False, **kwargs) -> List[RemoteExpert]: + def first_k_active(self, uid_prefixes: List[str], k: int, **kwargs): + """ TODO(jheuristic) remove this, also remove future arg from _first_k_active """ + future, _future = MPFuture.make_pair() + self.pipe.send(('_first_k_active', [], dict(uid_prefixes=uid_prefixes, k=k, future=_future, + max_pending=k, time_budget=float('inf'), **kwargs))) + return future.result() + + async def _first_k_active(self, node: DHTNode, uid_prefixes: List[str], k: int, max_pending: int, + time_budget: float, future=None, **kwargs) -> TOrderedDict[str, RemoteExpert]: """ - Batch-parallel version of find_best_experts (see find_best_experts docstring for details) - The only exception is that grid_scores must now be a list of 2d tensors [batch_size, grid_size[i]] - :returns: a list of batch_size lists, each contains *up to* k_best RemoteExpert instances + Find k prefixes with active experts; may return less if there aren't enough; used for DMoE beam search + + :param uid_prefixes: a list of (unique) uid prefixes ordered from highest to lowest priority + :param k: return at most *this many* active prefixes + :param max_pending: dispatches up to this many GET requests in parallel at any point in time + :param time_budget: if the procedure goes on for this many seconds, stop and return the best found experts + :returns: a ordered dict{uid_prefix -> RemoteExpert} mapping at most :k: prefixes to matching experts + The keys in the returned dict are ordered same as in uid_prefixes. 
""" - if grid_indices is None: - grid_indices = [torch.arange(len(dim_scores)) for dim_scores in grid_scores] - grid_scores = [dim_scores.cpu().detach() for dim_scores in grid_scores] - grid_indices = [dim_indices.cpu() for dim_indices in grid_indices] - assert len(grid_indices) == len(grid_scores), "grid_indices (if provided) must be of same length as grid_scores" - batch_size = len(grid_scores[0]) - assert all(dim_scores.ndim == 2 and dim_indices.ndim == 1 and len(dim_scores) == len(dim_indices) == batch_size - for dim, dim_scores, dim_indices in enumerate(zip(grid_scores, grid_indices))) - future, _future = MPFuture.make_pair() - self.pipe.send(('_batch_find_best_experts', [], - dict(prefix=prefix, grid_scores=grid_scores, grid_indices=grid_indices, k_best=k_best, - time_budget=time_budget, future=_future, **kwargs))) - return future if return_future else future.result() + deadline_time = time.perf_counter() + time_budget + unattempted_uids_reversed = list(reversed(uid_prefixes)) + pending_tasks: Deque[Tuple[str, asyncio.Task]] = deque() + found: TOrderedDict[str, RemoteExpert] = OrderedDict() + + while len(found) < k and (unattempted_uids_reversed or pending_tasks): + # dispatch additional tasks + while unattempted_uids_reversed and len(pending_tasks) < max_pending: + uid_prefix = unattempted_uids_reversed.pop() + pending_tasks.append((uid_prefix, asyncio.create_task(node.get(uid_prefix, **kwargs)))) + + uid_prefix, getter_task = pending_tasks.popleft() + try: + maybe_expert_data, maybe_expiration_time = await asyncio.wait_for( + getter_task, timeout=deadline_time - time.perf_counter()) + if maybe_expiration_time is not None: # found active expert + found[uid_prefix] = RemoteExpert(**maybe_expert_data) + + except asyncio.TimeoutError: + # pick up pending tasks that have have already returned result + while len(found) < k and pending_tasks: + uid_prefix, getter_task = pending_tasks.popleft() + if getter_task.done(): + maybe_expert_data, maybe_expiration_time = getter_task.result() + if maybe_expiration_time is not None: # found active expert + found[uid_prefix] = RemoteExpert(**maybe_expert_data) + + # check for remaining uids in node's local storage/cache + while len(found) < k and unattempted_uids_reversed: + uid_prefix = unattempted_uids_reversed.pop() + maybe_expert_data, maybe_expiration_time = node.get_local(uid_prefix) + if maybe_expiration_time is not None: # found active expert + found[uid_prefix] = RemoteExpert(**maybe_expert_data) + + # cancel remaining tasks + for uid_prefix, getter_task in pending_tasks: + getter_task.cancel() + + break # we ran out of time, return what we have + + except asyncio.CancelledError: + for key, getter_task in pending_tasks: + getter_task.cancel() + raise - async def _batch_find_best_experts(self, node: DHTNode, prefix: str, grid_scores: List[torch.Tensor], - grid_indices: List[torch.Tensor], k_best: int, time_budget: float = float('inf'), - future: Optional[MPFuture] = None, **kwargs) -> List[List[RemoteExpert]]: - results = await asyncio.gather(*[ - asyncio.create_task( - self._find_best_experts(node, prefix, grid_scores=grid_scores_i, grid_indices=grid_indices, - k_best=k_best, time_budget=time_budget, **kwargs) - ) for grid_scores_i in map(list, zip(*grid_scores)) - ]) if future: - future.set_result(results) - return results - + future.set_result(OrderedDict(found)) + return OrderedDict(found) From 60bc1c95ac3431ab8f4f02f8a479bfb6434c7b94 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Wed, 16 Sep 2020 14:18:03 +0300 Subject: [PATCH 08/16] 
rollback return_futures for now --- hivemind/dht/node.py | 53 ++++++++++++++++++++++++++------------------ 1 file changed, 31 insertions(+), 22 deletions(-) diff --git a/hivemind/dht/node.py b/hivemind/dht/node.py index aab6b4f1b..4811e67f8 100644 --- a/hivemind/dht/node.py +++ b/hivemind/dht/node.py @@ -327,17 +327,9 @@ async def get(self, key: DHTKey, latest=False, **kwargs) -> Tuple[Optional[DHTVa result = await self.get_many([key]) return result[key] - def get_local(self, key: DHTKey) -> Tuple[Optional[DHTValue], Optional[DHTExpiration]]: - """ Like DHTNode.get, but only search for key in node's local storage and cache """ - key_id = DHTID.generate(source=key) - maybe_value_bytes, maybe_expiration = self.protocol.storage.get(key_id) - maybe_cached_value, maybe_cache_expiration = self.protocol.cache.get(key_id) - if (maybe_cache_expiration or -float('inf')) > (maybe_expiration or -float('inf')): - maybe_value_bytes, maybe_expiration = maybe_cached_value, maybe_cache_expiration - return maybe_value_bytes, maybe_expiration - async def get_many(self, keys: Collection[DHTKey], sufficient_expiration_time: Optional[DHTExpiration] = None, - **kwargs) -> Dict[DHTKey, Tuple[Optional[DHTValue], Optional[DHTExpiration]]]: + **kwargs) -> Dict[DHTKey, Union[Tuple[Optional[DHTValue], Optional[DHTExpiration]], + Awaitable[Tuple[Optional[DHTValue], Optional[DHTExpiration]]]]]: """ Traverse DHT to find a list of keys. For each key, return latest (value, expiration) or None if not found. @@ -357,8 +349,9 @@ async def get_many(self, keys: Collection[DHTKey], sufficient_expiration_time: O async def get_many_by_id( self, key_ids: Collection[DHTID], sufficient_expiration_time: Optional[DHTExpiration] = None, - num_workers: Optional[int] = None, beam_size: Optional[int] = None, _refresh_cache=True - ) -> Dict[DHTID, Tuple[Optional[DHTValue], Optional[DHTExpiration]]]: + num_workers: Optional[int] = None, beam_size: Optional[int] = None, return_futures: bool = False, + _refresh_cache=True) -> Dict[DHTID, Union[Tuple[Optional[DHTValue], Optional[DHTExpiration]], + Awaitable[Tuple[Optional[DHTValue], Optional[DHTExpiration]]]]]: """ Traverse DHT to find a list of DHTIDs. For each key, return latest (value, expiration) or None if not found. @@ -368,6 +361,9 @@ async def get_many_by_id( If min_expiration_time=float('inf'), this method will find a value with _latest_ expiration :param beam_size: maintains up to this many nearest nodes when crawling dht, default beam_size = bucket_size :param num_workers: override for default num_workers, see traverse_dht num_workers param + :param return_futures: if True, immediately return asyncio.Future for every before interacting with the nework. + The algorithm will populate these futures with (value, expiration) when it finds the corresponding key + Note: canceling a future will stop search for the corresponding key :param _refresh_cache: internal flag, whether or not to self._trigger_cache_refresh :returns: for each key: value and its expiration time. 
If nothing is found, returns (None, None) for that key :note: in order to check if get returned a value, please check (expiration_time is None) @@ -426,13 +422,16 @@ async def found_callback(key_id: DHTID, nearest_nodes: List[DHTID], _visited: Se get_neighbors=get_neighbors, visited_nodes={key_id: {self.node_id} for key_id in unfinished_key_ids}, found_callback=found_callback, await_all_tasks=False)) - try: - # note: this should be first time when we await something, there's no need to "try" the entire function - return {key_id: await search_result for key_id, search_result in search_results.items()} - except asyncio.CancelledError as e: # terminate remaining tasks ASAP - for key_id, search_result in search_results.items(): - search_result.finish_search() - raise e + if return_futures: + return {key_id: search_result.future for key_id, search_result in search_results.items()} + else: + try: + # note: this should be first time when we await something, there's no need to "try" the entire function + return {key_id: await search_result.future for key_id, search_result in search_results.items()} + except asyncio.CancelledError as e: # terminate remaining tasks ASAP + for key_id, search_result in search_results.items(): + search_result.future.cancel() + raise e def _reuse_finished_search_result(self, finished: _IntermediateResult): expiration_time_threshold = max(finished.expiration_time or -float('inf'), finished.sufficient_expiration_time) @@ -522,6 +521,19 @@ async def _refresh_routing_table(self, *, period: Optional[float]) -> None: await asyncio.sleep(max(0.0, period - (get_dht_time() - refresh_time))) + def local_get(self, key: DHTKey) -> Tuple[Optional[DHTValue], Optional[DHTExpiration]]: + """ (synchronous) Like DHTNode.get, but only search for key in node's local storage and cache """ + key_id = DHTID.generate(source=key) + maybe_value_bytes, maybe_expiration = self.protocol.storage.get(key_id) + maybe_cached_value, maybe_cache_expiration = self.protocol.cache.get(key_id) + if (maybe_cache_expiration or -float('inf')) > (maybe_expiration or -float('inf')): + maybe_value_bytes, maybe_expiration = maybe_cached_value, maybe_cache_expiration + return maybe_value_bytes, maybe_expiration + + def local_cache(self, key: DHTKey, value: DHTValue, expiration_time: DHTExpiration) -> bool: + """ (synchronous) Add key->value pair to this node's local cache until expiration_time """ + return self.protocol.cache.store(DHTID.generate(key), self.serializer.dumps(value), expiration_time) + @dataclass(init=True, repr=True, frozen=False, order=False) class _IntermediateResult: @@ -563,6 +575,3 @@ def finished(self) -> bool: def __lt__(self, other: _IntermediateResult): """ _IntermediateResult instances will be sorted by their target expiration time """ return self.sufficient_expiration_time < other.sufficient_expiration_time - - def __await__(self): - return self.future.__await__() \ No newline at end of file From 4b9df9dc13224ad5f0eb42725262ac43deebb75d Mon Sep 17 00:00:00 2001 From: justheuristic Date: Wed, 16 Sep 2020 14:18:21 +0300 Subject: [PATCH 09/16] rename get_local -> local_get --- hivemind/dht/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hivemind/dht/__init__.py b/hivemind/dht/__init__.py index 7a5207ef9..4342e7977 100644 --- a/hivemind/dht/__init__.py +++ b/hivemind/dht/__init__.py @@ -291,7 +291,7 @@ async def _first_k_active(self, node: DHTNode, uid_prefixes: List[str], k: int, # check for remaining uids in node's local storage/cache while len(found) < k and 
unattempted_uids_reversed: uid_prefix = unattempted_uids_reversed.pop() - maybe_expert_data, maybe_expiration_time = node.get_local(uid_prefix) + maybe_expert_data, maybe_expiration_time = node.local_get(uid_prefix) if maybe_expiration_time is not None: # found active expert found[uid_prefix] = RemoteExpert(**maybe_expert_data) From 6118dc8a7da8bcd847814807b29eb0f070c0b10c Mon Sep 17 00:00:00 2001 From: justheuristic Date: Wed, 16 Sep 2020 14:42:49 +0300 Subject: [PATCH 10/16] rename get_local -> get_locally --- hivemind/dht/__init__.py | 2 +- hivemind/dht/node.py | 27 ++++++++++++++------------- 2 files changed, 15 insertions(+), 14 deletions(-) diff --git a/hivemind/dht/__init__.py b/hivemind/dht/__init__.py index 4342e7977..bac1fea3b 100644 --- a/hivemind/dht/__init__.py +++ b/hivemind/dht/__init__.py @@ -291,7 +291,7 @@ async def _first_k_active(self, node: DHTNode, uid_prefixes: List[str], k: int, # check for remaining uids in node's local storage/cache while len(found) < k and unattempted_uids_reversed: uid_prefix = unattempted_uids_reversed.pop() - maybe_expert_data, maybe_expiration_time = node.local_get(uid_prefix) + maybe_expert_data, maybe_expiration_time = node.get_locally(uid_prefix) if maybe_expiration_time is not None: # found active expert found[uid_prefix] = RemoteExpert(**maybe_expert_data) diff --git a/hivemind/dht/node.py b/hivemind/dht/node.py index 4811e67f8..e436e501d 100644 --- a/hivemind/dht/node.py +++ b/hivemind/dht/node.py @@ -211,6 +211,11 @@ async def get_neighbors(peer: DHTID, queries: Collection[DHTID]) -> Dict[DHTID, nearest_nodes_with_endpoints[query] = {node: node_to_endpoint[node] for node in nearest_nodes[:k_nearest]} return nearest_nodes_with_endpoints + def store_locally(self, key: DHTKey, value: DHTValue, expiration_time: DHTExpiration, in_cache=False) -> bool: + """ (synchronous) Add key->value pair to this node's local storage or cache until expiration_time """ + chosen_storage = self.protocol.cache if in_cache else self.protocol.storage + return chosen_storage.store(DHTID.generate(key), self.serializer.dumps(value), expiration_time) + async def store(self, key: DHTKey, value: DHTValue, expiration_time: DHTExpiration, **kwargs) -> bool: """ Find num_replicas best nodes to store (key, value) and store it there at least until expiration time. @@ -314,6 +319,15 @@ async def on_found(key_id: DHTID, nearest_nodes: List[DHTID], visited_nodes: Set store_task.cancel() raise e + def get_locally(self, key: DHTKey) -> Tuple[Optional[DHTValue], Optional[DHTExpiration]]: + """ (synchronous) Search for key in this node's local storage and cache, return latest (None if not found) """ + key_id = DHTID.generate(source=key) + maybe_value_bytes, maybe_expiration = self.protocol.storage.get(key_id) + maybe_cached_value, maybe_cache_expiration = self.protocol.cache.get(key_id) + if (maybe_cache_expiration or -float('inf')) > (maybe_expiration or -float('inf')): + maybe_value_bytes, maybe_expiration = maybe_cached_value, maybe_cache_expiration + return maybe_value_bytes, maybe_expiration + async def get(self, key: DHTKey, latest=False, **kwargs) -> Tuple[Optional[DHTValue], Optional[DHTExpiration]]: """ Search for a key across DHT and return either first or latest entry. 
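
The get_locally helper added above resolves storage/cache conflicts with a single rule: whichever entry expires later wins, and a miss in both places yields (None, None). Below is a minimal standalone sketch of that comparison; the helper name, type alias and plain tuples are illustrative only, not part of hivemind:

    from typing import Optional, Tuple

    LocalEntry = Tuple[Optional[bytes], Optional[float]]  # (serialized value, expiration time)

    def pick_freshest(storage_entry: LocalEntry, cache_entry: LocalEntry) -> LocalEntry:
        """Return whichever entry has the later expiration; missing entries compare as -inf."""
        value, expiration = storage_entry
        cached_value, cached_expiration = cache_entry
        if (cached_expiration or -float('inf')) > (expiration or -float('inf')):
            value, expiration = cached_value, cached_expiration
        return value, expiration

    assert pick_freshest((None, None), (None, None)) == (None, None)        # miss in both
    assert pick_freshest((b'old', 10.0), (b'new', 20.0)) == (b'new', 20.0)  # cache entry is fresher

Because this lookup never touches the network, the local fallback in _first_k_active stays cheap: once the time budget runs out, the remaining prefixes can still be checked synchronously with no extra round-trips.
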
@@ -521,19 +535,6 @@ async def _refresh_routing_table(self, *, period: Optional[float]) -> None: await asyncio.sleep(max(0.0, period - (get_dht_time() - refresh_time))) - def local_get(self, key: DHTKey) -> Tuple[Optional[DHTValue], Optional[DHTExpiration]]: - """ (synchronous) Like DHTNode.get, but only search for key in node's local storage and cache """ - key_id = DHTID.generate(source=key) - maybe_value_bytes, maybe_expiration = self.protocol.storage.get(key_id) - maybe_cached_value, maybe_cache_expiration = self.protocol.cache.get(key_id) - if (maybe_cache_expiration or -float('inf')) > (maybe_expiration or -float('inf')): - maybe_value_bytes, maybe_expiration = maybe_cached_value, maybe_cache_expiration - return maybe_value_bytes, maybe_expiration - - def local_cache(self, key: DHTKey, value: DHTValue, expiration_time: DHTExpiration) -> bool: - """ (synchronous) Add key->value pair to this node's local cache until expiration_time """ - return self.protocol.cache.store(DHTID.generate(key), self.serializer.dumps(value), expiration_time) - @dataclass(init=True, repr=True, frozen=False, order=False) class _IntermediateResult: From 8e469579b9184ebac4b3ee16b0bdb0f30d3a05a9 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Wed, 16 Sep 2020 14:46:55 +0300 Subject: [PATCH 11/16] rollback return_future --- hivemind/dht/__init__.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/hivemind/dht/__init__.py b/hivemind/dht/__init__.py index bac1fea3b..a98c6e582 100644 --- a/hivemind/dht/__init__.py +++ b/hivemind/dht/__init__.py @@ -242,12 +242,12 @@ async def _find_best_experts(self, node: DHTNode, prefix: str, grid_scores: List future.set_result(beam_experts) return beam_experts - def first_k_active(self, uid_prefixes: List[str], k: int, **kwargs): + def first_k_active(self, uid_prefixes: List[str], k: int, return_future=False, **kwargs): """ TODO(jheuristic) remove this, also remove future arg from _first_k_active """ future, _future = MPFuture.make_pair() self.pipe.send(('_first_k_active', [], dict(uid_prefixes=uid_prefixes, k=k, future=_future, max_pending=k, time_budget=float('inf'), **kwargs))) - return future.result() + return future if return_future else future.result() async def _first_k_active(self, node: DHTNode, uid_prefixes: List[str], k: int, max_pending: int, time_budget: float, future=None, **kwargs) -> TOrderedDict[str, RemoteExpert]: From 819b07e30d0e7f4d5fa2af59661ee49a1974c11c Mon Sep 17 00:00:00 2001 From: justheuristic Date: Wed, 16 Sep 2020 14:59:03 +0300 Subject: [PATCH 12/16] formatting --- hivemind/dht/__init__.py | 21 ++++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/hivemind/dht/__init__.py b/hivemind/dht/__init__.py index a98c6e582..6e2916c52 100644 --- a/hivemind/dht/__init__.py +++ b/hivemind/dht/__init__.py @@ -242,15 +242,9 @@ async def _find_best_experts(self, node: DHTNode, prefix: str, grid_scores: List future.set_result(beam_experts) return beam_experts - def first_k_active(self, uid_prefixes: List[str], k: int, return_future=False, **kwargs): - """ TODO(jheuristic) remove this, also remove future arg from _first_k_active """ - future, _future = MPFuture.make_pair() - self.pipe.send(('_first_k_active', [], dict(uid_prefixes=uid_prefixes, k=k, future=_future, - max_pending=k, time_budget=float('inf'), **kwargs))) - return future if return_future else future.result() - - async def _first_k_active(self, node: DHTNode, uid_prefixes: List[str], k: int, max_pending: int, - time_budget: float, future=None, 
**kwargs) -> TOrderedDict[str, RemoteExpert]: + def first_k_active( + self, uid_prefixes: List[str], k: int, max_pending: Optional[int] = None, time_budget: float = float('inf'), + return_future: bool = False, **kwargs) -> Union[TOrderedDict[str, RemoteExpert], Awaitable]: """ Find k prefixes with active experts; may return less if there aren't enough; used for DMoE beam search @@ -258,9 +252,18 @@ async def _first_k_active(self, node: DHTNode, uid_prefixes: List[str], k: int, :param k: return at most *this many* active prefixes :param max_pending: dispatches up to this many GET requests in parallel at any point in time :param time_budget: if the procedure goes on for this many seconds, stop and return the best found experts + :param return_future: if True, returns MPFuture that can be awaited to get result, default = return result :returns: a ordered dict{uid_prefix -> RemoteExpert} mapping at most :k: prefixes to matching experts The keys in the returned dict are ordered same as in uid_prefixes. """ + logger.warning("DHT.first_k_active is deprecated and will be removed in v0.9") + future, _future = MPFuture.make_pair() + self.pipe.send(('_first_k_active', [], dict(uid_prefixes=uid_prefixes, k=k, future=_future, + max_pending=max_pending or k, time_budget=time_budget, **kwargs))) + return future if return_future else future.result() + + async def _first_k_active(self, node: DHTNode, uid_prefixes: List[str], k: int, max_pending: int, + time_budget: float, future=None, **kwargs) -> TOrderedDict[str, RemoteExpert]: deadline_time = time.perf_counter() + time_budget unattempted_uids_reversed = list(reversed(uid_prefixes)) pending_tasks: Deque[Tuple[str, asyncio.Task]] = deque() From 76d68b0d56250e5dfdadc18fa73958d6fc2608a2 Mon Sep 17 00:00:00 2001 From: justheuristic Date: Wed, 16 Sep 2020 14:59:15 +0300 Subject: [PATCH 13/16] better description for store_locally --- hivemind/dht/node.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/hivemind/dht/node.py b/hivemind/dht/node.py index e436e501d..58e285139 100644 --- a/hivemind/dht/node.py +++ b/hivemind/dht/node.py @@ -212,7 +212,10 @@ async def get_neighbors(peer: DHTID, queries: Collection[DHTID]) -> Dict[DHTID, return nearest_nodes_with_endpoints def store_locally(self, key: DHTKey, value: DHTValue, expiration_time: DHTExpiration, in_cache=False) -> bool: - """ (synchronous) Add key->value pair to this node's local storage or cache until expiration_time """ + """ + (synchronous) Add key->value pair to this node's local storage or cache until expiration_time + Note: this does NOT guarantee that the key will be available to other peers. Use DHTNode.store for that. + """ chosen_storage = self.protocol.cache if in_cache else self.protocol.storage return chosen_storage.store(DHTID.generate(key), self.serializer.dumps(value), expiration_time) From 84f5f345370a21136a3e865fd30f67ec126b7ebe Mon Sep 17 00:00:00 2001 From: justheuristic Date: Wed, 16 Sep 2020 15:01:30 +0300 Subject: [PATCH 14/16] clarification --- hivemind/dht/protocol.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/hivemind/dht/protocol.py b/hivemind/dht/protocol.py index 1a2b37103..064e0c1f5 100644 --- a/hivemind/dht/protocol.py +++ b/hivemind/dht/protocol.py @@ -282,7 +282,7 @@ def _remove_outdated(self): def store(self, key: DHTID, value: BinaryDHTValue, expiration_time: DHTExpiration) -> bool: """ Store a (key, value) pair locally at least until expiration_time. See class docstring for details. 
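
store_locally in the previous patch and the DHTProtocol-level storage it writes into follow the same newest-wins rule: a store succeeds unless the proposed entry is already expired or the key already holds a value with a later expiration. A toy stand-in illustrating that rule (not the real LocalStorage class, just a sketch):

    import time
    from typing import Dict, Tuple

    class NewestWinsStore:
        """Keeps at most one value per key; rejects writes older than what is already stored."""
        def __init__(self):
            self._data: Dict[bytes, Tuple[bytes, float]] = {}

        def store(self, key: bytes, value: bytes, expiration_time: float) -> bool:
            if expiration_time < time.time():
                return False  # proposed entry has already expired
            _, current_expiration = self._data.get(key, (b'', -float('inf')))
            if expiration_time < current_expiration:
                return False  # rejected: there is a newer value for that key
            self._data[key] = (value, expiration_time)
            return True

    store = NewestWinsStore()
    assert store.store(b'k', b'v1', time.time() + 60)      # accepted: no previous value
    assert not store.store(b'k', b'v0', time.time() + 1)   # rejected: stored value expires later

Whether a write that exactly ties the stored expiration is accepted is an implementation detail; this sketch accepts it, the real class may not.
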
- :returns: True if new value was stored, False it was rejected (current value is newer) + :returns: True if new value was stored, False it was rejected (e.g. if there is a newer value for that key) """ if expiration_time < get_dht_time() and not self.frozen: return False From 9120a7b6c004be1801397806189315dad3c678ae Mon Sep 17 00:00:00 2001 From: justheuristic Date: Wed, 30 Sep 2020 04:16:40 +0300 Subject: [PATCH 15/16] switch to prefetch_rate --- hivemind/dht/__init__.py | 27 +++++++++++++++------------ 1 file changed, 15 insertions(+), 12 deletions(-) diff --git a/hivemind/dht/__init__.py b/hivemind/dht/__init__.py index 6e2916c52..f5a167ca7 100644 --- a/hivemind/dht/__init__.py +++ b/hivemind/dht/__init__.py @@ -203,7 +203,7 @@ def find_best_experts( :param kwargs: extra keyword parameters passed to self.dht.first_k_active :returns: a list that contains *up to* k_best RemoteExpert instances """ - grid_scores = list(map(np.asanyrray, grid_scores)) + grid_scores = list(map(np.asanyarray, grid_scores)) assert all(dim_scores.ndim == 1 for dim_scores in grid_scores) future, _future = MPFuture.make_pair() @@ -211,8 +211,10 @@ def find_best_experts( time_budget=time_budget, future=_future, **kwargs))) return future if return_future else future.result() - async def _find_best_experts(self, node: DHTNode, prefix: str, grid_scores: List[np.ndarray], k_best: int, - time_budget: float, future: MPFuture, **kwargs) -> List[RemoteExpert]: + async def _find_best_experts( + self, node: DHTNode, prefix: str, grid_scores: List[np.ndarray], k_best: int, time_budget: float, + max_prefetch: Optional[int] = None, future: MPFuture = None, **kwargs) -> List[RemoteExpert]: + max_prefetch = max_prefetch or k_best deadline_time = time.perf_counter() + time_budget beam_experts: List[RemoteExpert] = [] beam: List[str] = [prefix] @@ -229,7 +231,8 @@ async def _find_best_experts(self, node: DHTNode, prefix: str, grid_scores: List # select k best candidates according to scores but only those that are still active best_alive_prefixes: TOrderedDict[str, RemoteExpert] = await self._first_k_active( - node, sorted_prefix_uids, k=k_best, time_budget=deadline_time - time.perf_counter(), **kwargs) + node, sorted_prefix_uids, k=k_best, max_prefetch=max_prefetch, + time_budget=deadline_time - time.perf_counter(), **kwargs) if not best_alive_prefixes: logger.warning(f"Grid is empty: found neither of {sorted_prefix_uids}") @@ -242,15 +245,14 @@ async def _find_best_experts(self, node: DHTNode, prefix: str, grid_scores: List future.set_result(beam_experts) return beam_experts - def first_k_active( - self, uid_prefixes: List[str], k: int, max_pending: Optional[int] = None, time_budget: float = float('inf'), - return_future: bool = False, **kwargs) -> Union[TOrderedDict[str, RemoteExpert], Awaitable]: + def first_k_active(self, uid_prefixes: List[str], k: int, prefetch_rate: int = 1, time_budget: float = float('inf'), + return_future: bool = False, **kwargs) -> Union[TOrderedDict[str, RemoteExpert], Awaitable]: """ Find k prefixes with active experts; may return less if there aren't enough; used for DMoE beam search :param uid_prefixes: a list of (unique) uid prefixes ordered from highest to lowest priority :param k: return at most *this many* active prefixes - :param max_pending: dispatches up to this many GET requests in parallel at any point in time + :param prefetch_rate: dispatch up to this many GET requests in parallel for each uid_prefix :param time_budget: if the procedure goes on for this many seconds, stop and return the 
best found experts :param return_future: if True, returns MPFuture that can be awaited to get result, default = return result :returns: a ordered dict{uid_prefix -> RemoteExpert} mapping at most :k: prefixes to matching experts @@ -258,11 +260,12 @@ def first_k_active( """ logger.warning("DHT.first_k_active is deprecated and will be removed in v0.9") future, _future = MPFuture.make_pair() - self.pipe.send(('_first_k_active', [], dict(uid_prefixes=uid_prefixes, k=k, future=_future, - max_pending=max_pending or k, time_budget=time_budget, **kwargs))) + self.pipe.send(('_first_k_active', [], dict( + uid_prefixes=uid_prefixes, k=k, future=_future, prefetch_rate=prefetch_rate or k, + time_budget=time_budget, **kwargs))) return future if return_future else future.result() - async def _first_k_active(self, node: DHTNode, uid_prefixes: List[str], k: int, max_pending: int, + async def _first_k_active(self, node: DHTNode, uid_prefixes: List[str], k: int, prefetch_rate: int, time_budget: float, future=None, **kwargs) -> TOrderedDict[str, RemoteExpert]: deadline_time = time.perf_counter() + time_budget unattempted_uids_reversed = list(reversed(uid_prefixes)) @@ -271,7 +274,7 @@ async def _first_k_active(self, node: DHTNode, uid_prefixes: List[str], k: int, while len(found) < k and (unattempted_uids_reversed or pending_tasks): # dispatch additional tasks - while unattempted_uids_reversed and len(pending_tasks) < max_pending: + while unattempted_uids_reversed and len(pending_tasks) < (prefetch_rate * (k - len(found))): uid_prefix = unattempted_uids_reversed.pop() pending_tasks.append((uid_prefix, asyncio.create_task(node.get(uid_prefix, **kwargs)))) From 63e2de6a9a6221e84b9b1f854c5500268dc77fec Mon Sep 17 00:00:00 2001 From: justheuristic Date: Wed, 30 Sep 2020 05:18:16 +0300 Subject: [PATCH 16/16] prefetch rate --- hivemind/dht/__init__.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/hivemind/dht/__init__.py b/hivemind/dht/__init__.py index f5a167ca7..1a8ff50ed 100644 --- a/hivemind/dht/__init__.py +++ b/hivemind/dht/__init__.py @@ -19,7 +19,7 @@ import warnings from collections import deque, OrderedDict from concurrent.futures import ThreadPoolExecutor -from typing import List, Tuple, Optional, Sequence, OrderedDict as TOrderedDict, Union, Awaitable, Dict, Deque +from typing import List, Tuple, Optional, Sequence, OrderedDict as TOrderedDict, Union, Awaitable, Dict, Deque, Set import uvloop import numpy as np @@ -213,8 +213,7 @@ def find_best_experts( async def _find_best_experts( self, node: DHTNode, prefix: str, grid_scores: List[np.ndarray], k_best: int, time_budget: float, - max_prefetch: Optional[int] = None, future: MPFuture = None, **kwargs) -> List[RemoteExpert]: - max_prefetch = max_prefetch or k_best + prefetch_rate: int = 1, future: MPFuture = None, **kwargs) -> List[RemoteExpert]: deadline_time = time.perf_counter() + time_budget beam_experts: List[RemoteExpert] = [] beam: List[str] = [prefix] @@ -231,7 +230,7 @@ async def _find_best_experts( # select k best candidates according to scores but only those that are still active best_alive_prefixes: TOrderedDict[str, RemoteExpert] = await self._first_k_active( - node, sorted_prefix_uids, k=k_best, max_prefetch=max_prefetch, + node, sorted_prefix_uids, k=k_best, prefetch_rate=prefetch_rate, time_budget=deadline_time - time.perf_counter(), **kwargs) if not best_alive_prefixes: