Commit: Fix comments by @mryab

borzunov committed Jul 10, 2021
1 parent a4b31b6 commit 3c7b903

Showing 7 changed files with 44 additions and 35 deletions.
11 changes: 6 additions & 5 deletions hivemind/dht/__init__.py
@@ -15,7 +15,6 @@
 from __future__ import annotations
 
 import asyncio
-import ctypes
 import multiprocessing as mp
 import os
 from concurrent.futures import ThreadPoolExecutor
@@ -27,8 +26,8 @@
 from hivemind.dht.node import DHTNode, DHTID
 from hivemind.dht.routing import DHTValue, DHTKey, Subkey
 from hivemind.dht.validation import CompositeValidator, RecordValidatorBase
+from hivemind.p2p import P2P
 from hivemind.utils import MPFuture, get_logger, switch_to_uvloop, ValueWithExpiration, await_cancelled, DHTExpiration
-from hivemind.utils.networking import Hostname, Endpoint, strip_port
 
 logger = get_logger(__name__)
 
@@ -50,6 +49,9 @@ class DHT(mp.Process):
     :param max_workers: declare_experts and get_experts will use up to this many parallel workers
       (but no more than one per key)
     :param expiration: experts declared from this node expire after this many seconds (default = 5 minutes)
+    :param record_validators: instances of RecordValidatorBase used for signing and validating stored records.
+      The validators will be combined using the CompositeValidator class. It merges them when possible
+      (according to their `.merge_with()` policies) and orders them according to the `.priority` properties.
     :param shutdown_timeout: when calling .shutdown, wait for up to this many seconds before terminating
     :param kwargs: any other params will be forwarded to DHTNode and hivemind.p2p.P2P upon creation
     """
@@ -58,7 +60,7 @@ class DHT(mp.Process):
     def __init__(self, p2p: Optional[P2P] = None,
                  initial_peers: Optional[Sequence[Union[Multiaddr, str]]] = None,
                  *, start: bool, daemon: bool = True, max_workers: Optional[int] = None,
-                 parallel_rpc: Optional[int] = None, record_validators: Iterable[RecordValidatorBase] = (),
+                 record_validators: Iterable[RecordValidatorBase] = (),
                  shutdown_timeout: float = 3, **kwargs):
         super().__init__()
 
@@ -69,7 +71,6 @@ def __init__(self, p2p: Optional[P2P] = None,
         self.initial_peers = initial_peers
         self.kwargs = kwargs
         self.max_workers = max_workers
-        self.parallel_rpc = parallel_rpc
 
         self._record_validator = CompositeValidator(record_validators)
         self._inner_pipe, self._outer_pipe = mp.Pipe(duplex=True)
@@ -86,7 +87,7 @@ def run(self) -> None:
         with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:
             async def _run():
                 self._node = await DHTNode.create(
-                    p2p=self.p2p, initial_peers=self.initial_peers, parallel_rpc=self.parallel_rpc,
+                    p2p=self.p2p, initial_peers=self.initial_peers,
                     num_workers=self.max_workers or 1, record_validator=self._record_validator,
                     **self.kwargs)
                 self.ready.set()
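For context, the new record_validators parameter documented above can be used roughly as follows. This is an illustrative sketch, not part of the commit; it assumes hivemind.dht.crypto.RSASignatureValidator, a RecordValidatorBase implementation shipped with the library:

    from hivemind.dht import DHT
    from hivemind.dht.crypto import RSASignatureValidator

    # DHT wraps the sequence in a CompositeValidator, which merges validators
    # where their .merge_with() policies allow and orders them by .priority.
    dht = DHT(start=True, record_validators=[RSASignatureValidator()])
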
7 changes: 4 additions & 3 deletions hivemind/dht/node.py
@@ -6,7 +6,8 @@
 from collections import defaultdict, Counter
 from dataclasses import dataclass, field
 from functools import partial
-from typing import Any, Awaitable, Callable, Collection, DefaultDict, Dict, List, Optional, Sequence, Set, Tuple, Union
+from typing import (Any, Awaitable, Callable, Collection, DefaultDict, Dict, List, Optional, Sequence, Set, Tuple,
+                    Type, Union)
 
 from multiaddr import Multiaddr
 from sortedcontainers import SortedSet
@@ -128,7 +129,7 @@ async def create(
           if False, this node will refuse any incoming request, effectively being only a "client"
         :param record_validator: instance of RecordValidatorBase used for signing and validating stored records
         :param authorizer: instance of AuthorizerBase used for signing and validating requests and response
-          for following some authorization protocol
+          for a given authorization protocol
         :param kwargs: extra parameters for an internally created instance of hivemind.p2p.P2P.
           Should be empty if the P2P instance is provided in the constructor
         """
@@ -666,7 +667,7 @@ class _SearchState:
     expiration_time: Optional[DHTExpiration] = None  # best expiration time so far
     source_node_id: Optional[DHTID] = None  # node that gave us the value
     future: asyncio.Future[Optional[ValueWithExpiration[DHTValue]]] = field(default_factory=asyncio.Future)
-    serializer: type(SerializerBase) = MSGPackSerializer
+    serializer: Type[SerializerBase] = MSGPackSerializer
     record_validator: Optional[RecordValidatorBase] = None
 
     def add_candidate(self, candidate: Optional[ValueWithExpiration[Union[BinaryDHTValue, DictionaryDHTValue]]],
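A note on the serializer fix above: type(SerializerBase) calls the builtin type() when the annotation is evaluated and yields the metaclass of SerializerBase, which says nothing useful to a type checker. typing.Type[SerializerBase] is the standard way to annotate "SerializerBase itself or any of its subclasses". A self-contained illustration (toy classes, not hivemind's):

    from typing import Type

    class SerializerBase: ...
    class MSGPackSerializer(SerializerBase): ...

    # The field holds a class (not an instance), so Type[...] is the right annotation.
    serializer: Type[SerializerBase] = MSGPackSerializer
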
8 changes: 4 additions & 4 deletions hivemind/dht/protocol.py
@@ -1,4 +1,4 @@
-""" RPC protocol that provides nodes a way to communicate with each other. Based on gRPC.AIO. """
+""" RPC protocol that provides nodes a way to communicate with each other """
 from __future__ import annotations
 
 import asyncio
@@ -95,7 +95,7 @@ async def call_ping(self, peer: Endpoint, validate: bool = False, strict: bool =
             response = await self.get_stub(peer).rpc_ping(ping_request, timeout=self.wait_timeout)
             time_responded = get_dht_time()
         except Exception as e:
-            logger.debug(f"DHTProtocol failed to ping {peer}: {e}")
+            logger.exception(f"DHTProtocol failed to ping {peer}")
             response = None
         responded = bool(response and response.peer and response.peer.node_id)
 
@@ -189,7 +189,7 @@ async def call_store(self, peer: Endpoint, keys: Sequence[DHTID],
                 asyncio.create_task(self.update_routing_table(peer_id, peer, responded=True))
             return response.store_ok
         except Exception as e:
-            logger.debug(f"DHTProtocol failed to store at {peer}: {e}")
+            logger.exception(f"DHTProtocol failed to store at {peer}")
             asyncio.create_task(self.update_routing_table(self.routing_table.get(endpoint=peer), peer, responded=False))
             return None
 
@@ -275,7 +275,7 @@ async def call_find(self, peer: Endpoint, keys: Collection[DHTID]) -> Optional[D
 
             return output
         except Exception as e:
-            logger.debug(f"DHTProtocol failed to find at {peer}: {e}")
+            logger.exception(f"DHTProtocol failed to find at {peer}")
             asyncio.create_task(self.update_routing_table(self.routing_table.get(endpoint=peer), peer, responded=False))
 
     async def rpc_find(self, request: dht_pb2.FindRequest, context: P2PContext) -> dht_pb2.FindResponse:
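The three logging changes in this file share one rationale: logger.debug(f"...: {e}") records only the exception's message, at DEBUG level, while logger.exception(...) logs at ERROR level and automatically appends the traceback of the exception being handled. A standalone illustration with the standard logging module (not hivemind code):

    import logging

    logging.basicConfig(level=logging.DEBUG)
    logger = logging.getLogger(__name__)

    try:
        {}["missing"]
    except Exception as e:
        logger.debug(f"lookup failed: {e}")  # one line, traceback lost
        logger.exception("lookup failed")    # message plus full stack trace
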
17 changes: 10 additions & 7 deletions hivemind/p2p/p2p_daemon.py
@@ -82,7 +82,7 @@ async def create(cls,
                      use_relay: bool = True, use_relay_hop: bool = False,
                      use_relay_discovery: bool = False, use_auto_relay: bool = False, relay_hop_limit: int = 0,
                      quiet: bool = True,
-                     ping_n_retries: int = 5, ping_retry_delay: float = 0.4) -> 'P2P':
+                     ping_n_attempts: int = 5, ping_delay: float = 0.4) -> 'P2P':
         """
         Start a new p2pd process and connect to it.
         :param initial_peers: List of bootstrap peers
@@ -103,6 +103,9 @@ async def create(cls,
         :param use_auto_relay: enables autorelay
         :param relay_hop_limit: sets the hop limit for hop relays
         :param quiet: make the daemon process quiet
+        :param ping_n_attempts: try to ping the daemon with this number of attempts after starting it
+        :param ping_delay: wait for ``ping_delay * (2 ** (k - 1))`` seconds before the k-th attempt to ping the daemon
+          (in particular, wait for ``ping_delay`` seconds before the first attempt)
         :return: a wrapper for the p2p daemon
         """
 
@@ -139,13 +142,13 @@ async def create(cls,
         self._alive = True
         self._client = p2pclient.Client(self._daemon_listen_maddr, self._client_listen_maddr)
 
-        await self._ping_daemon_with_retries(ping_n_retries, ping_retry_delay)
+        await self._ping_daemon_with_retries(ping_n_attempts, ping_delay)
 
         return self
 
-    async def _ping_daemon_with_retries(self, ping_n_retries: int, ping_retry_delay: float) -> None:
-        for try_number in range(ping_n_retries):
-            await asyncio.sleep(ping_retry_delay * (2 ** try_number))
+    async def _ping_daemon_with_retries(self, ping_n_attempts: int, ping_delay: float) -> None:
+        for try_number in range(ping_n_attempts):
+            await asyncio.sleep(ping_delay * (2 ** try_number))
 
             if self._child.poll() is not None:  # Process died
                 break
@@ -154,8 +157,8 @@ async def _ping_daemon_with_retries(self, ping_n_retries: int, ping_retry_delay:
                 await self._ping_daemon()
                 break
             except Exception as e:
-                if try_number == ping_n_retries - 1:
-                    logger.error(f'Failed to ping p2pd: {e}')
+                if try_number == ping_n_attempts - 1:
+                    logger.exception('Failed to ping p2pd that has just started')
                     await self.shutdown()
                     raise
 
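To make the renamed retry parameters concrete: the loop above sleeps ping_delay * 2 ** try_number before each attempt (try_number starting at 0), which matches the docstring's ping_delay * (2 ** (k - 1)) for the k-th attempt. With the defaults from the signature, the schedule works out to:

    ping_n_attempts, ping_delay = 5, 0.4  # defaults shown above
    delays = [ping_delay * 2 ** (k - 1) for k in range(1, ping_n_attempts + 1)]
    assert delays == [0.4, 0.8, 1.6, 3.2, 6.4]  # about 12.4 s total before giving up
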
9 changes: 5 additions & 4 deletions hivemind/p2p/servicer.py
@@ -25,8 +25,8 @@ class StubBase:
     """
 
     def __init__(self, p2p: P2P, peer: PeerID):
-        self.p2p = p2p
-        self.peer = peer
+        self._p2p = p2p
+        self._peer = peer
 
 
 class Servicer:
@@ -66,10 +66,11 @@ def __init__(self):
 
     @staticmethod
    def _make_rpc_caller(handler: RPCHandler):
-        async def caller(stub: StubBase, request: handler.request_type,
+        # This method will be added to a new Stub type (a subclass of StubBase)
+        async def caller(self: StubBase, request: handler.request_type,
                          timeout: Optional[float] = None) -> handler.response_type:
             return await asyncio.wait_for(
-                stub.p2p.call_unary_handler(stub.peer, handler.handle_name, request, handler.response_type),
+                self._p2p.call_unary_handler(self._peer, handler.handle_name, request, handler.response_type),
                 timeout=timeout)
 
         caller.__name__ = handler.method_name
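The two changes in this file are connected: StubBase's attributes become _p2p and _peer, and the generated caller's first parameter is renamed from stub to self, because each caller is attached as a method of a dynamically created Stub subclass; underscore-prefixed attributes also cannot collide with the RPC method names added this way. A simplified sketch of the pattern (hypothetical names, not the actual hivemind implementation):

    import asyncio
    from typing import Optional

    class StubBase:
        def __init__(self, p2p, peer):
            self._p2p = p2p    # private: won't clash with generated rpc_* methods
            self._peer = peer

    def make_caller(handle_name: str):
        # The returned coroutine becomes a bound method of the generated class,
        # so its first parameter really is `self`.
        async def caller(self: StubBase, request, timeout: Optional[float] = None):
            return await asyncio.wait_for(
                self._p2p.call_unary_handler(self._peer, handle_name, request, bytes),
                timeout=timeout)
        caller.__name__ = handle_name
        return caller

    # One method per RPC handler, assembled at runtime:
    Stub = type("Stub", (StubBase,), {"rpc_ping": make_caller("rpc_ping")})
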
2 changes: 1 addition & 1 deletion tests/conftest.py
@@ -18,7 +18,7 @@ def cleanup_children():
 
     children = psutil.Process().children(recursive=True)
     if children:
-        logger.info(f'Cleaning up {len(children)} child processes')
+        logger.info(f'Cleaning up {len(children)} leftover child processes')
         for child in children:
             with suppress(psutil.NoSuchProcess):
                 child.terminate()
25 changes: 14 additions & 11 deletions tests/test_dht_node.py
@@ -4,20 +4,23 @@
 import random
 import signal
 import threading
-from itertools import chain, product
-from typing import Dict, List, Optional, Sequence, Tuple
+from itertools import product
+from typing import Dict, List, Optional, Tuple
 
 import numpy as np
 import pytest
 from multiaddr import Multiaddr
 
 import hivemind
-from hivemind import get_dht_time, replace_port
+from hivemind import get_dht_time
 from hivemind.dht.node import DHTID, Endpoint, DHTNode
-from hivemind.dht.protocol import DHTProtocol, ValidationError
+from hivemind.dht.protocol import DHTProtocol
 from hivemind.dht.storage import DictionaryDHTValue
 from hivemind.p2p import P2P, PeerID
 from hivemind.utils.networking import LOCALHOST
+from hivemind.utils.logging import get_logger
 
 
+logger = get_logger(__name__)
 
 
 def maddrs_to_peer_ids(maddrs: List[Multiaddr]) -> List[PeerID]:
@@ -34,7 +37,7 @@ def run_protocol_listener(dhtid: DHTID, maddr_conn: mp.connection.Connection,
     protocol = loop.run_until_complete(DHTProtocol.create(
         p2p, dhtid, bucket_size=20, depth_modulo=5, num_replicas=3, wait_timeout=5))
 
-    print(f"Started peer id={protocol.node_id} maddrs={maddrs}", flush=True)
+    logger.info(f"Started peer id={protocol.node_id} maddrs={maddrs}")
 
     if initial_peers is not None:
         for endpoint in maddrs_to_peer_ids(initial_peers):
@@ -44,7 +47,7 @@ def run_protocol_listener(dhtid: DHTID, maddr_conn: mp.connection.Connection,
 
     async def shutdown():
         await p2p.shutdown()
-        print(f"Finished peer id={protocol.node_id} maddrs={maddrs}", flush=True)
+        logger.info(f"Finished peer id={protocol.node_id} maddrs={maddrs}")
         loop.stop()
 
     loop.add_signal_handler(signal.SIGTERM, lambda: loop.create_task(shutdown()))
@@ -77,7 +80,7 @@ def test_dht_protocol():
     p2p = loop.run_until_complete(P2P.create(initial_peers=peer1_maddrs))
     protocol = loop.run_until_complete(DHTProtocol.create(
         p2p, DHTID.generate(), bucket_size=20, depth_modulo=5, wait_timeout=5, num_replicas=3, listen=listen))
-    print(f"Self id={protocol.node_id}", flush=True)
+    logger.info(f"Self id={protocol.node_id}")
 
     assert loop.run_until_complete(protocol.call_ping(peer1_endpoint)) == peer1_id
 
@@ -176,7 +179,7 @@ def run_node(initial_peers: List[Multiaddr], info_queue: mp.Queue):
     asyncio.set_event_loop(asyncio.new_event_loop())
     loop = asyncio.get_event_loop()
 
-    node = loop.run_until_complete(DHTNode.create(initial_peers=initial_peers, ping_n_retries=10))
+    node = loop.run_until_complete(DHTNode.create(initial_peers=initial_peers))
     maddrs = loop.run_until_complete(node.get_visible_maddrs())
 
     info_queue.put((node.node_id, node.endpoint, maddrs))
@@ -292,9 +295,9 @@ def test_dht_node():
         jaccard_denominator += k_nearest
 
     accuracy = accuracy_numerator / accuracy_denominator
-    print("Top-1 accuracy:", accuracy)  # should be 98-100%
+    logger.info(f"Top-1 accuracy: {accuracy}")  # should be 98-100%
     jaccard_index = jaccard_numerator / jaccard_denominator
-    print("Jaccard index (intersection over union):", jaccard_index)  # should be 95-100%
+    logger.info(f"Jaccard index (intersection over union): {jaccard_index}")  # should be 95-100%
     assert accuracy >= 0.9, f"Top-1 accuracy only {accuracy} ({accuracy_numerator} / {accuracy_denominator})"
     assert jaccard_index >= 0.9, f"Jaccard index only {accuracy} ({accuracy_numerator} / {accuracy_denominator})"
 
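On the metric logged above: the Jaccard index of two sets is J(A, B) = |A ∩ B| / |A ∪ B| (intersection over union), so a value near 1.0 means the search returned almost exactly the true nearest neighbors; the test accumulates the numerator and denominator over many queried keys. A tiny illustration (toy sets, not from the test):

    found, expected = {1, 2, 3, 4}, {2, 3, 4, 5}
    jaccard = len(found & expected) / len(found | expected)  # 3 / 5 = 0.6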
