diff --git a/.circleci/config.yml b/.circleci/config.yml index b1ab5978e..daaac9b6a 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1,5 +1,10 @@ version: 2.1 +parameters: + go-version: + type: string + default: 1.16.2 + jobs: build-and-test-py37: docker: @@ -9,6 +14,11 @@ jobs: - restore_cache: keys: - py37-v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }} + - v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }} + - run: | + wget https://golang.org/dl/go<< pipeline.parameters.go-version >>.linux-amd64.tar.gz -O go.tar.gz + tar -C ~/ -xzf go.tar.gz + echo "export PATH=~/go/bin:$PATH" >> $BASH_ENV - run: pip install -r requirements.txt - run: pip install -r requirements-dev.txt - save_cache: @@ -29,6 +39,10 @@ jobs: - restore_cache: keys: - py38-v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }} + - run: | + wget https://golang.org/dl/go<< pipeline.parameters.go-version >>.linux-amd64.tar.gz -O go.tar.gz + tar -C ~/ -xzf go.tar.gz + echo "export PATH=~/go/bin:$PATH" >> $BASH_ENV - run: pip install -r requirements.txt - run: pip install -r requirements-dev.txt - save_cache: @@ -49,6 +63,10 @@ jobs: - restore_cache: keys: - py39-v1-{{ checksum "requirements.txt" }}-{{ checksum "requirements-dev.txt" }} + - run: | + wget https://golang.org/dl/go<< pipeline.parameters.go-version >>.linux-amd64.tar.gz -O go.tar.gz + tar -C ~/ -xzf go.tar.gz + echo "export PATH=~/go/bin:$PATH" >> $BASH_ENV - run: pip install -r requirements.txt - run: pip install -r requirements-dev.txt - save_cache: diff --git a/.gitignore b/.gitignore index 965aa8972..61e239d1c 100644 --- a/.gitignore +++ b/.gitignore @@ -78,3 +78,6 @@ debian/files # protobuf stuff hivemind/proto/*_pb2* + +# libp2p-daemon binary +hivemind/hivemind_cli/p2pd diff --git a/hivemind/__init__.py b/hivemind/__init__.py index ebbfa0588..3fdfc625b 100644 --- a/hivemind/__init__.py +++ b/hivemind/__init__.py @@ -1,5 +1,6 @@ from hivemind.client import * from hivemind.dht import * +from hivemind.p2p import * from hivemind.server import * from hivemind.utils import * from hivemind.optim import * diff --git a/hivemind/p2p/__init__.py b/hivemind/p2p/__init__.py new file mode 100644 index 000000000..6bae0b8bd --- /dev/null +++ b/hivemind/p2p/__init__.py @@ -0,0 +1 @@ +from hivemind.p2p.p2p_daemon import P2P diff --git a/hivemind/p2p/p2p_daemon.py b/hivemind/p2p/p2p_daemon.py new file mode 100644 index 000000000..fa521716f --- /dev/null +++ b/hivemind/p2p/p2p_daemon.py @@ -0,0 +1,369 @@ +import asyncio +from copy import deepcopy +from dataclasses import dataclass +from importlib.resources import path +from subprocess import Popen +from typing import List, Optional + +import google.protobuf +from multiaddr import Multiaddr + +import hivemind.hivemind_cli as cli +import hivemind.p2p.p2p_daemon_bindings.p2pclient as p2pclient +from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, StreamInfo +from hivemind.proto import p2pd_pb2 +from hivemind.utils import MSGPackSerializer +from hivemind.utils.logging import get_logger +from hivemind.utils.networking import find_open_port + +logger = get_logger(__name__) + + +P2PD_FILENAME = 'p2pd' +NUM_RETRIES = 3 +RETRY_DELAY = 0.4 + + +class P2PInterruptedError(Exception): + pass + + +@dataclass(frozen=False) +class P2PContext(object): + id: str + port: int + handle_name: str + peer_id: PeerID = None + peer_addr: Multiaddr = None + + +class P2P: + """ + Forks a child process and executes p2pd command with given arguments. 
+ Can be used for peer to peer communication and procedure calls. + Sends SIGKILL to the child in destructor. + """ + + HEADER_LEN = 8 + BYTEORDER = 'big' + PB_HEADER_LEN = 1 + RESULT_MESSAGE = b'\x00' + ERROR_MESSAGE = b'\x01' + DHT_MODE_MAPPING = { + 'dht': {'dht': 1}, + 'dht_server': {'dhtServer': 1}, + 'dht_client': {'dhtClient': 1}, + } + FORCE_REACHABILITY_MAPPING = { + 'public': {'forceReachabilityPublic': 1}, + 'private': {'forceReachabilityPrivate': 1}, + } + + def __init__(self): + self._child = None + self._alive = False + self._listen_task = None + self._server_stopped = asyncio.Event() + + @classmethod + async def create(cls, *args, quic: bool = True, tls: bool = True, conn_manager: bool = True, + dht_mode: str = 'dht_server', force_reachability: Optional[str] = None, + nat_port_map: bool = True, auto_nat: bool = True, bootstrap: bool = False, + bootstrap_peers: Optional[List[str]] = None, use_global_ipfs: bool = False, host_port: int = None, + daemon_listen_port: int = None, **kwargs): + """ + Start a new p2pd process and connect to it. + :param args: + :param quic: Enables the QUIC transport + :param tls: Enables TLS1.3 channel security protocol + :param conn_manager: Enables the Connection Manager + :param dht_mode: DHT mode (dht_client/dht_server/dht) + :param force_reachability: Force reachability mode (public/private) + :param nat_port_map: Enables NAT port mapping + :param auto_nat: Enables the AutoNAT service + :param bootstrap: Connects to bootstrap peers and bootstraps the dht if enabled + :param bootstrap_peers: List of bootstrap peers; defaults to the IPFS DHT peers + :param use_global_ipfs: Bootstrap to global ipfs (works only if bootstrap=True and bootstrap_peers=None) + :param host_port: port for p2p network + :param daemon_listen_port: port for connection daemon and client binding + :param kwargs: + :return: new wrapper for p2p daemon + """ + + assert not (bootstrap and bootstrap_peers is None and not use_global_ipfs), \ + 'Trying to create with bootstrap node without bootstrap nodes list. ' \ + 'It is very dangerous, because p2pd connects to global ipfs and it is very unstable. ' \ + 'If you really want this, pass use_global_ipfs=True' + assert not (bootstrap_peers is not None and use_global_ipfs), \ + 'Non empty bootstrap_nodes and use_global_ipfs=True are incompatible.' 
\ + 'Choose one option: your nodes list (preferable) or global ipfs (very unstable)' + + self = cls() + with path(cli, P2PD_FILENAME) as p: + p2pd_path = p + bootstrap_peers = cls._make_bootstrap_peers(bootstrap_peers) + dht = cls.DHT_MODE_MAPPING.get(dht_mode, {'dht': 0}) + force_reachability = cls.FORCE_REACHABILITY_MAPPING.get(force_reachability, {}) + proc_args = self._make_process_args( + str(p2pd_path), *args, + quic=quic, tls=tls, connManager=conn_manager, + natPortMap=nat_port_map, autonat=auto_nat, + b=bootstrap, **{**bootstrap_peers, **dht, **force_reachability, **kwargs}) + self._assign_daemon_ports(host_port, daemon_listen_port) + + for try_count in range(NUM_RETRIES): + try: + self._initialize(proc_args) + await self._wait_for_client(RETRY_DELAY * (2 ** try_count)) + break + except Exception as e: + logger.debug(f"Failed to initialize p2p daemon: {e}") + self._terminate() + if try_count == NUM_RETRIES - 1: + raise + self._assign_daemon_ports() + + return self + + @classmethod + async def replicate(cls, daemon_listen_port: int, host_port: int): + """ + Connect to existing p2p daemon + :param daemon_listen_port: port for connection daemon and client binding + :param host_port: port for p2p network + :return: new wrapper for existing p2p daemon + """ + + self = cls() + # There is no child under control + # Use external already running p2pd + self._child = None + self._alive = True + self._assign_daemon_ports(host_port, daemon_listen_port) + self._client_listen_port = find_open_port() + self._client = p2pclient.Client( + Multiaddr(f'/ip4/127.0.0.1/tcp/{self._daemon_listen_port}'), + Multiaddr(f'/ip4/127.0.0.1/tcp/{self._client_listen_port}')) + await self._wait_for_client() + return self + + async def wait_for_at_least_n_peers(self, n_peers, attempts=3, delay=1): + for _ in range(attempts): + peers = await self._client.list_peers() + if len(peers) >= n_peers: + return + await asyncio.sleep(delay) + + raise RuntimeError('Not enough peers') + + def _initialize(self, proc_args: List[str]) -> None: + proc_args = deepcopy(proc_args) + proc_args.extend(self._make_process_args( + hostAddrs=f'/ip4/0.0.0.0/tcp/{self._host_port},/ip4/0.0.0.0/udp/{self._host_port}/quic', + listen=f'/ip4/127.0.0.1/tcp/{self._daemon_listen_port}' + )) + self._child = Popen(args=proc_args, encoding="utf8") + self._alive = True + self._client_listen_port = find_open_port() + self._client = p2pclient.Client( + Multiaddr(f'/ip4/127.0.0.1/tcp/{self._daemon_listen_port}'), + Multiaddr(f'/ip4/127.0.0.1/tcp/{self._client_listen_port}')) + + async def _wait_for_client(self, delay=0): + await asyncio.sleep(delay) + encoded = await self._client.identify() + self.id = encoded[0].to_base58() + + def _assign_daemon_ports(self, host_port=None, daemon_listen_port=None): + if host_port is None: + host_port = find_open_port() + if daemon_listen_port is None: + daemon_listen_port = find_open_port() + while daemon_listen_port == host_port: + daemon_listen_port = find_open_port() + + self._host_port, self._daemon_listen_port = host_port, daemon_listen_port + + @staticmethod + async def send_raw_data(byte_str, writer): + request = len(byte_str).to_bytes(P2P.HEADER_LEN, P2P.BYTEORDER) + byte_str + writer.write(request) + + @staticmethod + async def send_msgpack(data, writer): + raw_data = MSGPackSerializer.dumps(data) + await P2P.send_raw_data(raw_data, writer) + + @staticmethod + async def send_protobuf(protobuf, out_proto_type, writer): + if type(protobuf) != out_proto_type: + raise TypeError('Unary handler returned protobuf of 
wrong type.') + if out_proto_type == p2pd_pb2.RPCError: + await P2P.send_raw_data(P2P.ERROR_MESSAGE, writer) + else: + await P2P.send_raw_data(P2P.RESULT_MESSAGE, writer) + + await P2P.send_raw_data(protobuf.SerializeToString(), writer) + + @staticmethod + async def receive_raw_data(reader: asyncio.StreamReader, header_len=HEADER_LEN): + header = await reader.readexactly(header_len) + content_length = int.from_bytes(header, P2P.BYTEORDER) + data = await reader.readexactly(content_length) + return data + + @staticmethod + async def receive_msgpack(reader): + return MSGPackSerializer.loads(await P2P.receive_raw_data(reader)) + + @staticmethod + async def receive_protobuf(in_proto_type, reader): + msg_type = await P2P.receive_raw_data(reader) + if msg_type == P2P.RESULT_MESSAGE: + protobuf = in_proto_type() + protobuf.ParseFromString(await P2P.receive_raw_data(reader)) + return protobuf, None + elif msg_type == P2P.ERROR_MESSAGE: + protobuf = p2pd_pb2.RPCError() + protobuf.ParseFromString(await P2P.receive_raw_data(reader)) + return None, protobuf + else: + raise TypeError('Invalid Protobuf message type') + + @staticmethod + def _handle_stream(handle): + async def do_handle_stream(stream_info, reader, writer): + try: + request = await P2P.receive_raw_data(reader) + except asyncio.IncompleteReadError: + logger.debug("Incomplete read while receiving request from peer") + writer.close() + return + try: + result = handle(request) + await P2P.send_raw_data(result, writer) + finally: + writer.close() + + return do_handle_stream + + @staticmethod + def _handle_unary_stream(handle, context, in_proto_type, out_proto_type): + async def watchdog(reader: asyncio.StreamReader): + await reader.read(n=1) + raise P2PInterruptedError() + + async def do_handle_unary_stream( + stream_info: StreamInfo, + reader: asyncio.StreamReader, + writer: asyncio.StreamWriter) -> None: + try: + try: + request = await P2P.receive_protobuf(in_proto_type, reader) + except asyncio.IncompleteReadError: + logger.debug("Incomplete read while receiving request from peer") + return + except google.protobuf.message.DecodeError as error: + logger.exception(error) + return + + context.peer_id, context.peer_addr = stream_info.peer_id, stream_info.addr + done, pending = await asyncio.wait([watchdog(reader), handle(request, context)], + return_when=asyncio.FIRST_COMPLETED) + try: + result = done.pop().result() + await P2P.send_protobuf(result, out_proto_type, writer) + except P2PInterruptedError: + pass + except Exception as exc: + error = p2pd_pb2.RPCError(message=str(exc)) + await P2P.send_protobuf(error, p2pd_pb2.RPCError, writer) + finally: + pending_task = pending.pop() + pending_task.cancel() + try: + await pending_task + except asyncio.CancelledError: + pass + finally: + writer.close() + + return do_handle_unary_stream + + def start_listening(self): + async def listen(): + async with self._client.listen(): + await self._server_stopped.wait() + + self._listen_task = asyncio.create_task(listen()) + + async def stop_listening(self): + if self._listen_task is not None: + self._server_stopped.set() + self._listen_task.cancel() + try: + await self._listen_task + except asyncio.CancelledError: + self._listen_task = None + self._server_stopped.clear() + + async def add_stream_handler(self, name, handle): + if self._listen_task is None: + self.start_listening() + await self._client.stream_handler(name, self._handle_stream(handle)) + + async def add_unary_handler(self, name, handle, in_proto_type, out_proto_type): + if self._listen_task is 
None: + self.start_listening() + context = P2PContext(id=self.id, port=self._host_port, handle_name=name) + await self._client.stream_handler( + name, P2P._handle_unary_stream(handle, context, in_proto_type, out_proto_type)) + + async def call_peer_handler(self, peer_id, handler_name, input_data): + libp2p_peer_id = PeerID.from_base58(peer_id) + stream_info, reader, writer = await self._client.stream_open(libp2p_peer_id, (handler_name,)) + try: + await P2P.send_raw_data(input_data, writer) + return await P2P.receive_raw_data(reader) + finally: + writer.close() + + def __del__(self): + self._terminate() + + @property + def is_alive(self): + return self._alive + + async def shutdown(self): + await asyncio.get_event_loop().run_in_executor(None, self._terminate) + + def _terminate(self): + self._alive = False + if self._child is not None and self._child.poll() is None: + self._child.kill() + self._child.wait() + + @staticmethod + def _make_process_args(*args, **kwargs) -> List[str]: + proc_args = [] + proc_args.extend( + str(entry) for entry in args + ) + proc_args.extend( + f'-{key}={P2P._convert_process_arg_type(value)}' if value is not None else f'-{key}' + for key, value in kwargs.items() + ) + return proc_args + + @staticmethod + def _convert_process_arg_type(val): + if isinstance(val, bool): + return 1 if val else 0 + return val + + @staticmethod + def _make_bootstrap_peers(nodes): + if nodes is None: + return {} + return {'bootstrapPeers': ','.join(nodes)} diff --git a/hivemind/p2p/p2p_daemon_bindings/__init__.py b/hivemind/p2p/p2p_daemon_bindings/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/hivemind/p2p/p2p_daemon_bindings/control.py b/hivemind/p2p/p2p_daemon_bindings/control.py new file mode 100644 index 000000000..2002338a2 --- /dev/null +++ b/hivemind/p2p/p2p_daemon_bindings/control.py @@ -0,0 +1,210 @@ +""" +Originally taken from: https://github.com/mhchia/py-libp2p-daemon-bindings +Licence: MIT +Author: Kevin Mai-Husan Chia +""" + +import asyncio +from contextlib import asynccontextmanager +from typing import (AsyncIterator, Awaitable, Callable, Dict, Iterable, + Sequence, Tuple) + +from multiaddr import Multiaddr, protocols + +from hivemind.p2p.p2p_daemon_bindings.datastructures import (PeerID, PeerInfo, + StreamInfo) +from hivemind.p2p.p2p_daemon_bindings.utils import (DispatchFailure, + raise_if_failed, + read_pbmsg_safe, + write_pbmsg) +from hivemind.proto import p2pd_pb2 as p2pd_pb +from hivemind.utils.logging import get_logger + +StreamHandler = Callable[[StreamInfo, asyncio.StreamReader, asyncio.StreamWriter], Awaitable[None]] + +SUPPORT_CONN_PROTOCOLS = ( + protocols.P_IP4, + # protocols.P_IP6, + protocols.P_UNIX, +) +SUPPORTED_PROTOS = ( + protocols.protocol_with_code(proto) for proto in SUPPORT_CONN_PROTOCOLS +) +logger = get_logger(__name__) + + +def parse_conn_protocol(maddr: Multiaddr) -> int: + proto_codes = set(proto.code for proto in maddr.protocols()) + proto_cand = proto_codes.intersection(SUPPORT_CONN_PROTOCOLS) + if len(proto_cand) != 1: + raise ValueError( + f"connection protocol should be only one protocol out of {SUPPORTED_PROTOS}" + f", maddr={maddr}" + ) + return tuple(proto_cand)[0] + + +class DaemonConnector: + DEFAULT_CONTROL_MADDR = "/unix/tmp/p2pd.sock" + + def __init__(self, control_maddr: Multiaddr = Multiaddr(DEFAULT_CONTROL_MADDR)) -> None: + self.control_maddr = control_maddr + self.proto_code = parse_conn_protocol(self.control_maddr) + + async def open_connection(self) -> (asyncio.StreamReader, asyncio.StreamWriter): 
+ if self.proto_code == protocols.P_UNIX: + control_path = self.control_maddr.value_for_protocol(protocols.P_UNIX) + logger.debug(f"DaemonConnector {self} opens connection to {self.control_maddr}") + return await asyncio.open_unix_connection(control_path) + elif self.proto_code == protocols.P_IP4: + host = self.control_maddr.value_for_protocol(protocols.P_IP4) + port = int(self.control_maddr.value_for_protocol(protocols.P_TCP)) + return await asyncio.open_connection(host, port) + else: + raise ValueError( + f"Protocol not supported: {protocols.protocol_with_code(self.proto_code)}" + ) + + +class ControlClient: + DEFAULT_LISTEN_MADDR = "/unix/tmp/p2pclient.sock" + + def __init__( + self, daemon_connector: DaemonConnector, listen_maddr: Multiaddr = Multiaddr(DEFAULT_LISTEN_MADDR) + ) -> None: + self.listen_maddr = listen_maddr + self.daemon_connector = daemon_connector + self.handlers: Dict[str, StreamHandler] = {} + + async def _handler(self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter): + pb_stream_info = p2pd_pb.StreamInfo() # type: ignore + await read_pbmsg_safe(reader, pb_stream_info) + stream_info = StreamInfo.from_protobuf(pb_stream_info) + logger.debug(f"New incoming stream: {stream_info}") + try: + handler = self.handlers[stream_info.proto] + except KeyError as e: + # should never enter here... daemon should reject the stream for us. + writer.close() + raise DispatchFailure(e) + await handler(stream_info, reader, writer) + + @asynccontextmanager + async def listen(self) -> AsyncIterator["ControlClient"]: + proto_code = parse_conn_protocol(self.listen_maddr) + if proto_code == protocols.P_UNIX: + listen_path = self.listen_maddr.value_for_protocol(protocols.P_UNIX) + server = await asyncio.start_unix_server(self._handler, path=listen_path) + elif proto_code == protocols.P_IP4: + host = self.listen_maddr.value_for_protocol(protocols.P_IP4) + port = int(self.listen_maddr.value_for_protocol(protocols.P_TCP)) + server = await asyncio.start_server(self._handler, port=port, host=host) + else: + raise ValueError( + f"Protocol not supported: {protocols.protocol_with_code(proto_code)}" + ) + + async with server: + logger.info(f"DaemonConnector {self} starts listening to {self.listen_maddr}") + yield self + + logger.info(f"DaemonConnector {self} closed") + + async def identify(self) -> Tuple[PeerID, Tuple[Multiaddr, ...]]: + reader, writer = await self.daemon_connector.open_connection() + req = p2pd_pb.Request(type=p2pd_pb.Request.IDENTIFY) + await write_pbmsg(writer, req) + + resp = p2pd_pb.Response() # type: ignore + await read_pbmsg_safe(reader, resp) + writer.close() + + raise_if_failed(resp) + peer_id_bytes = resp.identify.id + maddrs_bytes = resp.identify.addrs + + maddrs = tuple(Multiaddr(maddr_bytes) for maddr_bytes in maddrs_bytes) + peer_id = PeerID(peer_id_bytes) + + return peer_id, maddrs + + async def connect(self, peer_id: PeerID, maddrs: Iterable[Multiaddr]) -> None: + reader, writer = await self.daemon_connector.open_connection() + + maddrs_bytes = [i.to_bytes() for i in maddrs] + connect_req = p2pd_pb.ConnectRequest( + peer=peer_id.to_bytes(), addrs=maddrs_bytes + ) + req = p2pd_pb.Request(type=p2pd_pb.Request.CONNECT, connect=connect_req) + await write_pbmsg(writer, req) + + resp = p2pd_pb.Response() # type: ignore + await read_pbmsg_safe(reader, resp) + writer.close() + raise_if_failed(resp) + + async def list_peers(self) -> Tuple[PeerInfo, ...]: + req = p2pd_pb.Request(type=p2pd_pb.Request.LIST_PEERS) + reader, writer = await 
self.daemon_connector.open_connection() + await write_pbmsg(writer, req) + resp = p2pd_pb.Response() # type: ignore + await read_pbmsg_safe(reader, resp) + writer.close() + raise_if_failed(resp) + + peers = tuple(PeerInfo.from_protobuf(pinfo) for pinfo in resp.peers) + return peers + + async def disconnect(self, peer_id: PeerID) -> None: + disconnect_req = p2pd_pb.DisconnectRequest(peer=peer_id.to_bytes()) + req = p2pd_pb.Request( + type=p2pd_pb.Request.DISCONNECT, disconnect=disconnect_req + ) + reader, writer = await self.daemon_connector.open_connection() + await write_pbmsg(writer, req) + resp = p2pd_pb.Response() # type: ignore + await read_pbmsg_safe(reader, resp) + writer.close() + raise_if_failed(resp) + + async def stream_open( + self, peer_id: PeerID, protocols: Sequence[str] + ) -> Tuple[StreamInfo, asyncio.StreamReader, asyncio.StreamWriter]: + reader, writer = await self.daemon_connector.open_connection() + + stream_open_req = p2pd_pb.StreamOpenRequest( + peer=peer_id.to_bytes(), proto=list(protocols) + ) + req = p2pd_pb.Request( + type=p2pd_pb.Request.STREAM_OPEN, streamOpen=stream_open_req + ) + await write_pbmsg(writer, req) + + resp = p2pd_pb.Response() # type: ignore + await read_pbmsg_safe(reader, resp) + raise_if_failed(resp) + + pb_stream_info = resp.streamInfo + stream_info = StreamInfo.from_protobuf(pb_stream_info) + + return stream_info, reader, writer + + async def stream_handler(self, proto: str, handler_cb: StreamHandler) -> None: + reader, writer = await self.daemon_connector.open_connection() + + listen_path_maddr_bytes = self.listen_maddr.to_bytes() + stream_handler_req = p2pd_pb.StreamHandlerRequest( + addr=listen_path_maddr_bytes, proto=[proto] + ) + req = p2pd_pb.Request( + type=p2pd_pb.Request.STREAM_HANDLER, streamHandler=stream_handler_req + ) + await write_pbmsg(writer, req) + + resp = p2pd_pb.Response() # type: ignore + await read_pbmsg_safe(reader, resp) + writer.close() + raise_if_failed(resp) + + # if success, add the handler to the dict + self.handlers[proto] = handler_cb diff --git a/hivemind/p2p/p2p_daemon_bindings/datastructures.py b/hivemind/p2p/p2p_daemon_bindings/datastructures.py new file mode 100644 index 000000000..dab58c408 --- /dev/null +++ b/hivemind/p2p/p2p_daemon_bindings/datastructures.py @@ -0,0 +1,170 @@ +""" +Originally taken from: https://github.com/mhchia/py-libp2p-daemon-bindings +Licence: MIT +Author: Kevin Mai-Husan Chia +""" + +import hashlib +from typing import Any, Sequence, Union + +import base58 +import multihash +from multiaddr import Multiaddr, protocols + +from hivemind.proto import p2pd_pb2 + +# NOTE: On inlining... 
+# See: https://github.com/libp2p/specs/issues/138 +# NOTE: enabling to be interoperable w/ the Go implementation +ENABLE_INLINING = True +MAX_INLINE_KEY_LENGTH = 42 + +IDENTITY_MULTIHASH_CODE = 0x00 + +if ENABLE_INLINING: + + class IdentityHash: + def __init__(self) -> None: + self._digest = bytearray() + + def update(self, input: bytes) -> None: + self._digest += input + + def digest(self) -> bytes: + return self._digest + + multihash.FuncReg.register( + IDENTITY_MULTIHASH_CODE, "identity", hash_new=IdentityHash + ) + + +class PeerID: + def __init__(self, peer_id_bytes: bytes) -> None: + self._bytes = peer_id_bytes + self._xor_id = int(sha256_digest(self._bytes).hex(), 16) + self._b58_str = base58.b58encode(self._bytes).decode() + + @property + def xor_id(self) -> int: + return self._xor_id + + def to_bytes(self) -> bytes: + return self._bytes + + def to_base58(self) -> str: + return self._b58_str + + def __repr__(self) -> str: + return f"" + + def __str__(self): + return self.to_base58() + + def pretty(self): + return self.to_base58() + + def to_string(self): + return self.to_base58() + + def __eq__(self, other: object) -> bool: + if isinstance(other, str): + return self.to_base58() == other + elif isinstance(other, bytes): + return self._bytes == other + elif isinstance(other, PeerID): + return self._bytes == other._bytes + else: + return False + + def __hash__(self) -> int: + return hash(self._bytes) + + @classmethod + def from_base58(cls, base58_id: str) -> "PeerID": + peer_id_bytes = base58.b58decode(base58_id) + return cls(peer_id_bytes) + + +def sha256_digest(data: Union[str, bytes]) -> bytes: + if isinstance(data, str): + data = data.encode("utf8") + return hashlib.sha256(data).digest() + + +class StreamInfo: + def __init__(self, peer_id: PeerID, addr: Multiaddr, proto: str) -> None: + self.peer_id = peer_id + self.addr = addr + self.proto = proto + + def __repr__(self) -> str: + return ( + f"" + ) + + def to_protobuf(self) -> p2pd_pb2.StreamInfo: + pb_msg = p2pd_pb2.StreamInfo( + peer=self.peer_id.to_bytes(), addr=self.addr.to_bytes(), proto=self.proto + ) + return pb_msg + + @classmethod + def from_protobuf(cls, pb_msg: p2pd_pb2.StreamInfo) -> "StreamInfo": + stream_info = cls( + peer_id=PeerID(pb_msg.peer), addr=Multiaddr(pb_msg.addr), proto=pb_msg.proto + ) + return stream_info + + +class PeerInfo: + def __init__(self, peer_id: PeerID, addrs: Sequence[Multiaddr]) -> None: + self.peer_id = peer_id + self.addrs = list(addrs) + + def __eq__(self, other: Any) -> bool: + return ( + isinstance(other, PeerInfo) + and self.peer_id == other.peer_id + and self.addrs == other.addrs + ) + + @classmethod + def from_protobuf(cls, peer_info_pb: p2pd_pb2.PeerInfo) -> "PeerInfo": + peer_id = PeerID(peer_info_pb.id) + addrs = [Multiaddr(addr) for addr in peer_info_pb.addrs] + return PeerInfo(peer_id, addrs) + + def __str__(self): + return f"{self.peer_id.pretty()} {','.join(str(a) for a in self.addrs)}" + + +class InvalidAddrError(ValueError): + pass + + +def info_from_p2p_addr(addr: Multiaddr) -> PeerInfo: + if addr is None: + raise InvalidAddrError("`addr` should not be `None`") + + parts = addr.split() + if not parts: + raise InvalidAddrError( + f"`parts`={parts} should at least have a protocol `P_P2P`" + ) + + p2p_part = parts[-1] + last_protocol_code = p2p_part.protocols()[0].code + if last_protocol_code != protocols.P_P2P: + raise InvalidAddrError( + f"The last protocol should be `P_P2P` instead of `{last_protocol_code}`" + ) + + # make sure the /p2p value parses as a peer.ID + 
peer_id_str: str = p2p_part.value_for_protocol(protocols.P_P2P) + peer_id = PeerID.from_base58(peer_id_str) + + # we might have received just an / p2p part, which means there's no addr. + if len(parts) > 1: + addr = Multiaddr.join(*parts[:-1]) + + return PeerInfo(peer_id, [addr]) diff --git a/hivemind/p2p/p2p_daemon_bindings/p2pclient.py b/hivemind/p2p/p2p_daemon_bindings/p2pclient.py new file mode 100644 index 000000000..c1fe97808 --- /dev/null +++ b/hivemind/p2p/p2p_daemon_bindings/p2pclient.py @@ -0,0 +1,85 @@ +""" +Originally taken from: https://github.com/mhchia/py-libp2p-daemon-bindings +Licence: MIT +Author: Kevin Mai-Husan Chia +""" + +import asyncio +from contextlib import asynccontextmanager +from typing import AsyncIterator, Iterable, Sequence, Tuple + +from multiaddr import Multiaddr + +from hivemind.p2p.p2p_daemon_bindings.control import (ControlClient, + DaemonConnector, + StreamHandler) +from hivemind.p2p.p2p_daemon_bindings.datastructures import (PeerID, PeerInfo, + StreamInfo) + + +class Client: + control: ControlClient + + def __init__( + self, control_maddr: Multiaddr = None, listen_maddr: Multiaddr = None + ) -> None: + daemon_connector = DaemonConnector(control_maddr=control_maddr) + self.control = ControlClient( + daemon_connector=daemon_connector, listen_maddr=listen_maddr + ) + + @asynccontextmanager + async def listen(self) -> AsyncIterator["Client"]: + """ + Starts to listen incoming connections for handlers registered via stream_handler. + :return: + """ + async with self.control.listen(): + yield self + + async def identify(self) -> Tuple[PeerID, Tuple[Multiaddr, ...]]: + """ + Get current node peer id and list of addresses + """ + return await self.control.identify() + + async def connect(self, peer_id: PeerID, maddrs: Iterable[Multiaddr]) -> None: + """ + Connect to p2p node with specified addresses and peer id. + :peer_id: node peer id you want connect to + :maddrs: node multiaddresses you want connect to. Of course, it must be reachable. 
+ """ + await self.control.connect(peer_id=peer_id, maddrs=maddrs) + + async def list_peers(self) -> Tuple[PeerInfo, ...]: + """ + Get list of peers that node connect to + """ + return await self.control.list_peers() + + async def disconnect(self, peer_id: PeerID) -> None: + """ + Disconnect from node with specified peer id + :peer_id: node peer id you want disconnect from + """ + await self.control.disconnect(peer_id=peer_id) + + async def stream_open( + self, peer_id: PeerID, protocols: Sequence[str] + ) -> Tuple[StreamInfo, asyncio.StreamReader, asyncio.StreamWriter]: + """ + Open a stream to call other peer (with peer_id) handler for specified protocols + :peer_id: other peer id + :protocols: list of protocols for other peer handling + :return: Returns tuple of stream info (info about connection to second peer) and reader/writer + """ + return await self.control.stream_open(peer_id=peer_id, protocols=protocols) + + async def stream_handler(self, proto: str, handler_cb: StreamHandler) -> None: + """ + Register a stream handler + :param proto: protocols that handler serves + :param handler_cb: handler callback + :return: + """ + await self.control.stream_handler(proto=proto, handler_cb=handler_cb) diff --git a/hivemind/p2p/p2p_daemon_bindings/utils.py b/hivemind/p2p/p2p_daemon_bindings/utils.py new file mode 100644 index 000000000..2a0d5b97c --- /dev/null +++ b/hivemind/p2p/p2p_daemon_bindings/utils.py @@ -0,0 +1,73 @@ +""" +Originally taken from: https://github.com/mhchia/py-libp2p-daemon-bindings +Licence: MIT +Author: Kevin Mai-Husan Chia +""" + +import asyncio + +from google.protobuf.message import Message as PBMessage + +from hivemind.proto import p2pd_pb2 as p2pd_pb + +DEFAULT_MAX_BITS: int = 64 + + +class ControlFailure(Exception): + pass + + +class DispatchFailure(Exception): + pass + + +async def write_unsigned_varint(stream: asyncio.StreamWriter, integer: int, max_bits: int = DEFAULT_MAX_BITS) -> None: + max_int = 1 << max_bits + if integer < 0: + raise ValueError(f"negative integer: {integer}") + if integer >= max_int: + raise ValueError(f"integer too large: {integer}") + while True: + value = integer & 0x7F + integer >>= 7 + if integer != 0: + value |= 0x80 + byte = value.to_bytes(1, "big") + stream.write(byte) + if integer == 0: + break + + +async def read_unsigned_varint(stream: asyncio.StreamReader, max_bits: int = DEFAULT_MAX_BITS) -> int: + max_int = 1 << max_bits + iteration = 0 + result = 0 + has_next = True + while has_next: + data = await stream.readexactly(1) + c = data[0] + value = c & 0x7F + result |= value << (iteration * 7) + has_next = (c & 0x80) != 0 + iteration += 1 + if result >= max_int: + raise ValueError(f"Varint overflowed: {result}") + return result + + +def raise_if_failed(response: p2pd_pb.Response) -> None: + if response.type == p2pd_pb.Response.ERROR: + raise ControlFailure(f"Connect failed. 
msg={response.error.msg}") + + +async def write_pbmsg(stream: asyncio.StreamWriter, pbmsg: PBMessage) -> None: + size = pbmsg.ByteSize() + await write_unsigned_varint(stream, size) + msg_bytes: bytes = pbmsg.SerializeToString() + stream.write(msg_bytes) + + +async def read_pbmsg_safe(stream: asyncio.StreamReader, pbmsg: PBMessage) -> None: + len_msg_bytes = await read_unsigned_varint(stream) + msg_bytes = await stream.readexactly(len_msg_bytes) + pbmsg.ParseFromString(msg_bytes) diff --git a/hivemind/proto/p2pd.proto b/hivemind/proto/p2pd.proto new file mode 100644 index 000000000..373c6d8e9 --- /dev/null +++ b/hivemind/proto/p2pd.proto @@ -0,0 +1,166 @@ +//Originally taken from: https://github.com/mhchia/py-libp2p-daemon-bindings +//Licence: MIT +//Author: Kevin Mai-Husan Chia + +syntax = "proto2"; + +package p2pclient.p2pd.pb; + +message Request { + enum Type { + IDENTIFY = 0; + CONNECT = 1; + STREAM_OPEN = 2; + STREAM_HANDLER = 3; + DHT = 4; + LIST_PEERS = 5; + CONNMANAGER = 6; + DISCONNECT = 7; + PUBSUB = 8; + } + + required Type type = 1; + + optional ConnectRequest connect = 2; + optional StreamOpenRequest streamOpen = 3; + optional StreamHandlerRequest streamHandler = 4; + optional DHTRequest dht = 5; + optional ConnManagerRequest connManager = 6; + optional DisconnectRequest disconnect = 7; + optional PSRequest pubsub = 8; +} + +message Response { + enum Type { + OK = 0; + ERROR = 1; + } + + required Type type = 1; + optional ErrorResponse error = 2; + optional StreamInfo streamInfo = 3; + optional IdentifyResponse identify = 4; + optional DHTResponse dht = 5; + repeated PeerInfo peers = 6; + optional PSResponse pubsub = 7; +} + +message IdentifyResponse { + required bytes id = 1; + repeated bytes addrs = 2; +} + +message ConnectRequest { + required bytes peer = 1; + repeated bytes addrs = 2; + optional int64 timeout = 3; +} + +message StreamOpenRequest { + required bytes peer = 1; + repeated string proto = 2; + optional int64 timeout = 3; +} + +message StreamHandlerRequest { + required bytes addr = 1; + repeated string proto = 2; +} + +message ErrorResponse { + required string msg = 1; +} + +message StreamInfo { + required bytes peer = 1; + required bytes addr = 2; + required string proto = 3; +} + +message DHTRequest { + enum Type { + FIND_PEER = 0; + FIND_PEERS_CONNECTED_TO_PEER = 1; + FIND_PROVIDERS = 2; + GET_CLOSEST_PEERS = 3; + GET_PUBLIC_KEY = 4; + GET_VALUE = 5; + SEARCH_VALUE = 6; + PUT_VALUE = 7; + PROVIDE = 8; + } + + required Type type = 1; + optional bytes peer = 2; + optional bytes cid = 3; + optional bytes key = 4; + optional bytes value = 5; + optional int32 count = 6; + optional int64 timeout = 7; +} + +message DHTResponse { + enum Type { + BEGIN = 0; + VALUE = 1; + END = 2; + } + + required Type type = 1; + optional PeerInfo peer = 2; + optional bytes value = 3; +} + +message PeerInfo { + required bytes id = 1; + repeated bytes addrs = 2; +} + +message ConnManagerRequest { + enum Type { + TAG_PEER = 0; + UNTAG_PEER = 1; + TRIM = 2; + } + + required Type type = 1; + + optional bytes peer = 2; + optional string tag = 3; + optional int64 weight = 4; +} + +message DisconnectRequest { + required bytes peer = 1; +} + +message PSRequest { + enum Type { + GET_TOPICS = 0; + LIST_PEERS = 1; + PUBLISH = 2; + SUBSCRIBE = 3; + } + + required Type type = 1; + optional string topic = 2; + optional bytes data = 3; +} + +message PSMessage { + optional bytes from_id = 1; + optional bytes data = 2; + optional bytes seqno = 3; + repeated string topicIDs = 4; + optional bytes 
signature = 5; + optional bytes key = 6; +} + +message PSResponse { + repeated string topics = 1; + repeated bytes peerIDs = 2; +} + +message RPCError { + required string message = 1; +} diff --git a/requirements.txt b/requirements.txt index 7e2e84d93..9ce07c9eb 100644 --- a/requirements.txt +++ b/requirements.txt @@ -10,5 +10,7 @@ grpcio>=1.33.2 grpcio-tools>=1.33.2 protobuf>=3.12.2 configargparse>=1.2.3 +multiaddr>=0.0.9 +pymultihash>=0.8.2 cryptography>=3.4.6 pydantic>=1.8.1 diff --git a/setup.py b/setup.py index 53cb6b77d..2cdef0d36 100644 --- a/setup.py +++ b/setup.py @@ -1,13 +1,35 @@ import codecs import glob +import hashlib import os import re +import shlex +import subprocess +import tarfile +import tempfile +import urllib.request +from packaging import version from pkg_resources import parse_requirements -from setuptools import setup, find_packages +from setuptools import find_packages, setup from setuptools.command.develop import develop from setuptools.command.install import install +P2PD_VERSION = 'v0.3.1' +P2PD_CHECKSUM = '8810097959db720208cdc9f2945804a4' +LIBP2P_TAR_URL = f'https://github.com/learning-at-home/go-libp2p-daemon/archive/refs/tags/{P2PD_VERSION}.tar.gz' + + +here = os.path.abspath(os.path.dirname(__file__)) + + +def md5(fname, chunk_size=4096): + hash_md5 = hashlib.md5() + with open(fname, "rb") as f: + for chunk in iter(lambda: f.read(chunk_size), b""): + hash_md5.update(chunk) + return hash_md5.hexdigest() + def proto_compile(output_path): import grpc_tools.protoc @@ -28,20 +50,59 @@ def proto_compile(output_path): file.truncate() -class ProtoCompileInstall(install): +def libp2p_build_install(): + try: + result = subprocess.run("go version", capture_output=True, shell=True).stdout.decode('ascii', 'replace') + m = re.search(r'^go version go([\d.]+)', result) + v = m.group(1) + + if version.parse(v) < version.parse("1.13"): + raise EnvironmentError(f'Newer version of go required: must be >= 1.13, found {version}') + + except FileNotFoundError: + raise FileNotFoundError('Could not find golang installation') + + with tempfile.TemporaryDirectory() as tempdir: + dest = os.path.join(tempdir, 'libp2p-daemon.tar.gz') + urllib.request.urlretrieve(LIBP2P_TAR_URL, dest) + + with tarfile.open(dest, 'r:gz') as tar: + tar.extractall(tempdir) + + result = subprocess.run(f'go build -o {shlex.quote(os.path.join(here, "hivemind", "hivemind_cli", "p2pd"))}', + cwd=os.path.join(tempdir, f'go-libp2p-daemon-{P2PD_VERSION[1:]}', 'p2pd'), shell=True) + + if result.returncode: + raise RuntimeError('Failed to build or install libp2p-daemon:' + f' exited with status code: {result.returncode}') + + +def libp2p_download_install(): + install_path = os.path.join(here, 'hivemind', 'hivemind_cli') + binary_path = os.path.join(install_path, 'p2pd') + if 'p2pd' not in os.listdir(install_path) or md5(binary_path) != P2PD_CHECKSUM: + print('Downloading Peer to Peer Daemon') + url = f'https://github.com/learning-at-home/go-libp2p-daemon/releases/download/{P2PD_VERSION}/p2pd' + urllib.request.urlretrieve(url, binary_path) + os.chmod(binary_path, 0o777) + if md5(binary_path) != P2PD_CHECKSUM: + raise RuntimeError(f'Downloaded p2pd binary from {url} does not match with md5 checksum') + + +class Install(install): def run(self): + libp2p_download_install() proto_compile(os.path.join(self.build_lib, 'hivemind', 'proto')) super().run() -class ProtoCompileDevelop(develop): +class Develop(develop): def run(self): + libp2p_build_install() proto_compile(os.path.join('hivemind', 'proto')) super().run() -here = 
os.path.abspath(os.path.dirname(__file__)) - with open('requirements.txt') as requirements_file: install_requires = list(map(str, parse_requirements(requirements_file))) @@ -63,7 +124,7 @@ def run(self): setup( name='hivemind', version=version_string, - cmdclass={'install': ProtoCompileInstall, 'develop': ProtoCompileDevelop}, + cmdclass={'install': Install, 'develop': Develop}, description='Decentralized deep learning in PyTorch', long_description='Decentralized deep learning in PyTorch. Built to train giant models on ' 'thousands of volunteers across the world.', diff --git a/tests/test_p2p_daemon.py b/tests/test_p2p_daemon.py new file mode 100644 index 000000000..f8464d2a7 --- /dev/null +++ b/tests/test_p2p_daemon.py @@ -0,0 +1,440 @@ +import asyncio +import multiprocessing as mp +import subprocess +from functools import partial +from typing import List + +import numpy as np +import pytest +import torch + +from hivemind.p2p import P2P +from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID +from hivemind.proto import dht_pb2, runtime_pb2 +from hivemind.utils import MSGPackSerializer +from hivemind.utils.compression import deserialize_torch_tensor, serialize_torch_tensor + + +def is_process_running(pid: int) -> bool: + return subprocess.run(["ps", "-p", str(pid)], stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL).returncode == 0 + + +async def replicate_if_needed(p2p: P2P, replicate: bool): + return await P2P.replicate(p2p._daemon_listen_port, p2p._host_port) if replicate else p2p + + +def bootstrap_addr(host_port, id_): + return f'/ip4/127.0.0.1/tcp/{host_port}/p2p/{id_}' + + +def bootstrap_from(daemons: List[P2P]) -> List[str]: + return [bootstrap_addr(d._host_port, d.id) for d in daemons] + + +@pytest.mark.asyncio +async def test_daemon_killed_on_del(): + p2p_daemon = await P2P.create() + + child_pid = p2p_daemon._child.pid + assert is_process_running(child_pid) + + await p2p_daemon.shutdown() + assert not is_process_running(child_pid) + + +@pytest.mark.asyncio +async def test_server_client_connection(): + server = await P2P.create() + peers = await server._client.list_peers() + assert len(peers) == 0 + + nodes = bootstrap_from([server]) + client = await P2P.create(bootstrap=True, bootstrap_peers=nodes) + await client.wait_for_at_least_n_peers(1) + + peers = await client._client.list_peers() + assert len(peers) == 1 + peers = await server._client.list_peers() + assert len(peers) == 1 + + +@pytest.mark.asyncio +async def test_daemon_replica_does_not_affect_primary(): + p2p_daemon = await P2P.create() + p2p_replica = await P2P.replicate(p2p_daemon._daemon_listen_port, p2p_daemon._host_port) + + child_pid = p2p_daemon._child.pid + assert is_process_running(child_pid) + + await p2p_replica.shutdown() + assert is_process_running(child_pid) + + await p2p_daemon.shutdown() + assert not is_process_running(child_pid) + + +def handle_square(x): + x = MSGPackSerializer.loads(x) + return MSGPackSerializer.dumps(x ** 2) + + +def handle_add(args): + args = MSGPackSerializer.loads(args) + result = args[0] + for i in range(1, len(args)): + result = result + args[i] + return MSGPackSerializer.dumps(result) + + +def handle_square_torch(x): + tensor = runtime_pb2.Tensor() + tensor.ParseFromString(x) + tensor = deserialize_torch_tensor(tensor) + result = tensor ** 2 + return serialize_torch_tensor(result).SerializeToString() + + +def handle_add_torch(args): + args = MSGPackSerializer.loads(args) + tensor = runtime_pb2.Tensor() + tensor.ParseFromString(args[0]) + result = 
deserialize_torch_tensor(tensor) + + for i in range(1, len(args)): + tensor = runtime_pb2.Tensor() + tensor.ParseFromString(args[i]) + result = result + deserialize_torch_tensor(tensor) + + return serialize_torch_tensor(result).SerializeToString() + + +def handle_add_torch_with_exc(args): + try: + return handle_add_torch(args) + except Exception: + return b'something went wrong :(' + + +@pytest.mark.parametrize( + 'should_cancel,replicate', [ + (True, False), + (True, True), + (False, False), + (False, True), + ] +) +@pytest.mark.asyncio +async def test_call_unary_handler(should_cancel, replicate, handle_name="handle"): + handler_cancelled = False + + async def ping_handler(request, context): + try: + await asyncio.sleep(2) + except asyncio.CancelledError: + nonlocal handler_cancelled + handler_cancelled = True + return dht_pb2.PingResponse( + peer=dht_pb2.NodeInfo( + node_id=context.id.encode(), rpc_port=context.port), + sender_endpoint=context.handle_name, available=True) + + server_primary = await P2P.create() + server = await replicate_if_needed(server_primary, replicate) + server_pid = server_primary._child.pid + await server.add_unary_handler(handle_name, ping_handler, dht_pb2.PingRequest, + dht_pb2.PingResponse) + assert is_process_running(server_pid) + + nodes = bootstrap_from([server]) + client_primary = await P2P.create(bootstrap=True, bootstrap_peers=nodes) + client = await replicate_if_needed(client_primary, replicate) + client_pid = client_primary._child.pid + assert is_process_running(client_pid) + + ping_request = dht_pb2.PingRequest( + peer=dht_pb2.NodeInfo(node_id=client.id.encode(), rpc_port=client._host_port), + validate=True) + expected_response = dht_pb2.PingResponse( + peer=dht_pb2.NodeInfo(node_id=server.id.encode(), rpc_port=server._host_port), + sender_endpoint=handle_name, available=True) + + await client.wait_for_at_least_n_peers(1) + libp2p_server_id = PeerID.from_base58(server.id) + stream_info, reader, writer = await client._client.stream_open(libp2p_server_id, (handle_name,)) + + await P2P.send_protobuf(ping_request, dht_pb2.PingRequest, writer) + + if should_cancel: + writer.close() + await asyncio.sleep(1) + assert handler_cancelled + else: + result, err = await P2P.receive_protobuf(dht_pb2.PingResponse, reader) + assert err is None + assert result == expected_response + assert not handler_cancelled + + await server.stop_listening() + await server_primary.shutdown() + assert not is_process_running(server_pid) + + await client_primary.shutdown() + assert not is_process_running(client_pid) + + +@pytest.mark.asyncio +async def test_call_unary_handler_error(handle_name="handle"): + async def error_handler(request, context): + raise ValueError('boom') + + server = await P2P.create() + server_pid = server._child.pid + await server.add_unary_handler(handle_name, error_handler, dht_pb2.PingRequest, dht_pb2.PingResponse) + assert is_process_running(server_pid) + + nodes = bootstrap_from([server]) + client = await P2P.create(bootstrap=True, bootstrap_peers=nodes) + client_pid = client._child.pid + assert is_process_running(client_pid) + await client.wait_for_at_least_n_peers(1) + + ping_request = dht_pb2.PingRequest( + peer=dht_pb2.NodeInfo(node_id=client.id.encode(), rpc_port=client._host_port), + validate=True) + libp2p_server_id = PeerID.from_base58(server.id) + stream_info, reader, writer = await client._client.stream_open(libp2p_server_id, (handle_name,)) + + await P2P.send_protobuf(ping_request, dht_pb2.PingRequest, writer) + result, err = await 
P2P.receive_protobuf(dht_pb2.PingResponse, reader) + assert result is None + assert err.message == 'boom' + + await server.stop_listening() + await server.shutdown() + await client.shutdown() + + +@pytest.mark.parametrize( + "test_input,expected,handle", + [ + pytest.param(10, 100, handle_square, id="square_integer"), + pytest.param((1, 2), 3, handle_add, id="add_integers"), + pytest.param(([1, 2, 3], [12, 13]), [1, 2, 3, 12, 13], handle_add, id="add_lists"), + pytest.param(2, 8, lambda x: MSGPackSerializer.dumps(MSGPackSerializer.loads(x) ** 3), id="lambda") + ] +) +@pytest.mark.asyncio +async def test_call_peer_single_process(test_input, expected, handle, handler_name="handle"): + server = await P2P.create() + server_pid = server._child.pid + await server.add_stream_handler(handler_name, handle) + assert is_process_running(server_pid) + + nodes = bootstrap_from([server]) + client = await P2P.create(bootstrap=True, bootstrap_peers=nodes) + client_pid = client._child.pid + assert is_process_running(client_pid) + + await client.wait_for_at_least_n_peers(1) + + test_input_msgp = MSGPackSerializer.dumps(test_input) + result_msgp = await client.call_peer_handler(server.id, handler_name, test_input_msgp) + result = MSGPackSerializer.loads(result_msgp) + assert result == expected + + await server.stop_listening() + await server.shutdown() + assert not is_process_running(server_pid) + + await client.shutdown() + assert not is_process_running(client_pid) + + +async def run_server(handler_name, server_side, client_side, response_received): + server = await P2P.create() + server_pid = server._child.pid + await server.add_stream_handler(handler_name, handle_square) + assert is_process_running(server_pid) + + server_side.send(server.id) + server_side.send(server._host_port) + while response_received.value == 0: + await asyncio.sleep(0.5) + + await server.stop_listening() + await server.shutdown() + assert not is_process_running(server_pid) + + +def server_target(handler_name, server_side, client_side, response_received): + asyncio.run(run_server(handler_name, server_side, client_side, response_received)) + + +@pytest.mark.asyncio +async def test_call_peer_different_processes(): + handler_name = "square" + test_input = 2 + + server_side, client_side = mp.Pipe() + response_received = mp.Value(np.ctypeslib.as_ctypes_type(np.int32)) + response_received.value = 0 + + proc = mp.Process(target=server_target, args=(handler_name, server_side, client_side, response_received)) + proc.start() + + peer_id = client_side.recv() + peer_port = client_side.recv() + + nodes = [bootstrap_addr(peer_port, peer_id)] + client = await P2P.create(bootstrap=True, bootstrap_peers=nodes) + client_pid = client._child.pid + assert is_process_running(client_pid) + + await client.wait_for_at_least_n_peers(1) + + test_input_msgp = MSGPackSerializer.dumps(2) + result_msgp = await client.call_peer_handler(peer_id, handler_name, test_input_msgp) + result = MSGPackSerializer.loads(result_msgp) + assert np.allclose(result, test_input ** 2) + response_received.value = 1 + + await client.shutdown() + assert not is_process_running(client_pid) + + proc.join() + + +@pytest.mark.parametrize( + "test_input,expected", + [ + pytest.param(torch.tensor([2]), torch.tensor(4)), + pytest.param( + torch.tensor([[1.0, 2.0], [0.5, 0.1]]), + torch.tensor([[1.0, 2.0], [0.5, 0.1]]) ** 2), + ] +) +@pytest.mark.asyncio +async def test_call_peer_torch_square(test_input, expected, handler_name="handle"): + handle = handle_square_torch + server = await 
P2P.create() + await server.add_stream_handler(handler_name, handle) + + nodes = bootstrap_from([server]) + client = await P2P.create(bootstrap=True, bootstrap_peers=nodes) + + await client.wait_for_at_least_n_peers(1) + + inp = serialize_torch_tensor(test_input).SerializeToString() + result_pb = await client.call_peer_handler(server.id, handler_name, inp) + result = runtime_pb2.Tensor() + result.ParseFromString(result_pb) + result = deserialize_torch_tensor(result) + assert torch.allclose(result, expected) + + await server.stop_listening() + await server.shutdown() + await client.shutdown() + + +@pytest.mark.parametrize( + "test_input,expected", + [ + pytest.param([torch.tensor([1]), torch.tensor([2])], torch.tensor([3])), + pytest.param( + [torch.tensor([[0.1, 0.2], [0.3, 0.4]]), torch.tensor([[1.1, 1.2], [1.3, 1.4]])], + torch.tensor([[1.2, 1.4], [1.6, 1.8]])), + ] +) +@pytest.mark.asyncio +async def test_call_peer_torch_add(test_input, expected, handler_name="handle"): + handle = handle_add_torch + server = await P2P.create() + await server.add_stream_handler(handler_name, handle) + + nodes = bootstrap_from([server]) + client = await P2P.create(bootstrap=True, bootstrap_peers=nodes) + + await client.wait_for_at_least_n_peers(1) + + inp = [serialize_torch_tensor(i).SerializeToString() for i in test_input] + inp_msgp = MSGPackSerializer.dumps(inp) + result_pb = await client.call_peer_handler(server.id, handler_name, inp_msgp) + result = runtime_pb2.Tensor() + result.ParseFromString(result_pb) + result = deserialize_torch_tensor(result) + assert torch.allclose(result, expected) + + await server.stop_listening() + await server.shutdown() + await client.shutdown() + + +@pytest.mark.parametrize( + "replicate", + [ + pytest.param(False, id="primary"), + pytest.param(True, id="replica"), + ] +) +@pytest.mark.asyncio +async def test_call_peer_error(replicate, handler_name="handle"): + server_primary = await P2P.create() + server = await replicate_if_needed(server_primary, replicate) + await server.add_stream_handler(handler_name, handle_add_torch_with_exc) + + nodes = bootstrap_from([server]) + client_primary = await P2P.create(bootstrap=True, bootstrap_peers=nodes) + client = await replicate_if_needed(client_primary, replicate) + + await client.wait_for_at_least_n_peers(1) + + inp = [serialize_torch_tensor(i).SerializeToString() for i in [torch.zeros((2, 3)), torch.zeros((3, 2))]] + inp_msgp = MSGPackSerializer.dumps(inp) + result = await client.call_peer_handler(server.id, handler_name, inp_msgp) + assert result == b'something went wrong :(' + + await server.stop_listening() + await server_primary.shutdown() + await client_primary.shutdown() + + +@pytest.mark.asyncio +async def test_handlers_on_different_replicas(handler_name="handle"): + def handler(arg, key): + return key + + server_primary = await P2P.create(bootstrap=False) + server_id = server_primary.id + await server_primary.add_stream_handler(handler_name, partial(handler, key=b'primary')) + + server_replica1 = await replicate_if_needed(server_primary, True) + await server_replica1.add_stream_handler(handler_name + '1', partial(handler, key=b'replica1')) + + server_replica2 = await replicate_if_needed(server_primary, True) + await server_replica2.add_stream_handler(handler_name + '2', partial(handler, key=b'replica2')) + + nodes = bootstrap_from([server_primary]) + client = await P2P.create(bootstrap=True, bootstrap_peers=nodes) + await client.wait_for_at_least_n_peers(1) + + result = await client.call_peer_handler(server_id, 
handler_name, b'1') + assert result == b"primary" + + result = await client.call_peer_handler(server_id, handler_name + '1', b'2') + assert result == b"replica1" + + result = await client.call_peer_handler(server_id, handler_name + '2', b'3') + assert result == b"replica2" + + await server_replica1.stop_listening() + await server_replica2.stop_listening() + + # Primary does not handle replicas protocols + with pytest.raises(asyncio.IncompleteReadError): + await client.call_peer_handler(server_id, handler_name + '1', b'') + with pytest.raises(asyncio.IncompleteReadError): + await client.call_peer_handler(server_id, handler_name + '2', b'') + + await server_primary.stop_listening() + await server_primary.shutdown() + await client.shutdown() diff --git a/tests/test_p2p_daemon_bindings.py b/tests/test_p2p_daemon_bindings.py new file mode 100644 index 000000000..172bded54 --- /dev/null +++ b/tests/test_p2p_daemon_bindings.py @@ -0,0 +1,559 @@ +import asyncio +import io +from contextlib import AsyncExitStack + +import pytest +from google.protobuf.message import EncodeError +from multiaddr import Multiaddr, protocols + +from hivemind.p2p.p2p_daemon_bindings.control import ControlClient, DaemonConnector, parse_conn_protocol +from hivemind.p2p.p2p_daemon_bindings.datastructures import PeerID, PeerInfo, StreamInfo +from hivemind.p2p.p2p_daemon_bindings.utils import (ControlFailure, raise_if_failed, read_pbmsg_safe, + read_unsigned_varint, write_pbmsg, write_unsigned_varint) +from hivemind.proto import p2pd_pb2 as p2pd_pb +from test_utils import make_p2pd_pair_ip4, connect_safe + + +def test_raise_if_failed_raises(): + resp = p2pd_pb.Response() + resp.type = p2pd_pb.Response.ERROR + with pytest.raises(ControlFailure): + raise_if_failed(resp) + + +def test_raise_if_failed_not_raises(): + resp = p2pd_pb.Response() + resp.type = p2pd_pb.Response.OK + raise_if_failed(resp) + + +PAIRS_INT_SERIALIZED_VALID = ( + (0, b"\x00"), + (1, b"\x01"), + (128, b"\x80\x01"), + (2 ** 32, b"\x80\x80\x80\x80\x10"), + (2 ** 64 - 1, b"\xff\xff\xff\xff\xff\xff\xff\xff\xff\x01"), +) + +PAIRS_INT_SERIALIZED_OVERFLOW = ( + (2 ** 64, b"\x80\x80\x80\x80\x80\x80\x80\x80\x80\x02"), + (2 ** 64 + 1, b"\x81\x80\x80\x80\x80\x80\x80\x80\x80\x02"), + ( + 2 ** 128, + b"\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x80\x04", + ), +) + +PEER_ID_STRING = "QmS5QmciTXXnCUCyxud5eWFenUMAmvAWSDa1c7dvdXRMZ7" +PEER_ID_BYTES = b'\x12 7\x87F.[\xb5\xb1o\xe5*\xc7\xb9\xbb\x11:"Z|j2\x8ad\x1b\xa6\xe5= timeout: + # timeout + assert False, f"{coro_func} still failed after `{timeout}` seconds" + await asyncio.sleep(0.01) + + +class Daemon: + control_maddr = None + proc_daemon = None + log_filename = "" + f_log = None + closed = None + + def __init__( + self, control_maddr, enable_control, enable_connmgr, enable_dht, enable_pubsub + ): + self.control_maddr = control_maddr + self.enable_control = enable_control + self.enable_connmgr = enable_connmgr + self.enable_dht = enable_dht + self.enable_pubsub = enable_pubsub + self.is_closed = False + self._start_logging() + self._run() + + def _start_logging(self): + name_control_maddr = str(self.control_maddr).replace("/", "_").replace(".", "_") + self.log_filename = f"/tmp/log_p2pd{name_control_maddr}.txt" + self.f_log = open(self.log_filename, "wb") + + def _run(self): + cmd_list = ["hivemind/hivemind_cli/p2pd", f"-listen={str(self.control_maddr)}"] + cmd_list += [f"-hostAddrs=/ip4/127.0.0.1/tcp/{find_open_port()}"] + if self.enable_connmgr: + cmd_list += ["-connManager=true", "-connLo=1", 
"-connHi=2", "-connGrace=0"] + if self.enable_dht: + cmd_list += ["-dht=true"] + if self.enable_pubsub: + cmd_list += ["-pubsub=true", "-pubsubRouter=gossipsub"] + self.proc_daemon = subprocess.Popen( + cmd_list, stdout=self.f_log, stderr=self.f_log, bufsize=0 + ) + + async def wait_until_ready(self): + lines_head_pattern = (b"Control socket:", b"Peer ID:", b"Peer Addrs:") + lines_head_occurred = {line: False for line in lines_head_pattern} + + with open(self.log_filename, "rb") as f_log_read: + + async def read_from_daemon_and_check(): + line = f_log_read.readline() + for head_pattern in lines_head_occurred: + if line.startswith(head_pattern): + lines_head_occurred[head_pattern] = True + return all([value for _, value in lines_head_occurred.items()]) + + await try_until_success(read_from_daemon_and_check) + + # sleep for a while in case that the daemon haven't been ready after emitting these lines + await asyncio.sleep(0.1) + + def close(self): + if self.is_closed: + return + self.proc_daemon.terminate() + self.proc_daemon.wait() + self.f_log.close() + self.is_closed = True + + +class DaemonTuple(NamedTuple): + daemon: Daemon + client: Client + + +class ConnectionFailure(Exception): + pass + + +@asynccontextmanager +async def make_p2pd_pair_unix( + enable_control, enable_connmgr, enable_dht, enable_pubsub +): + name = str(uuid.uuid4())[:8] + control_maddr = Multiaddr(f"/unix/tmp/test_p2pd_control_{name}.sock") + listen_maddr = Multiaddr(f"/unix/tmp/test_p2pd_listen_{name}.sock") + # Remove the existing unix socket files if they are existing + try: + os.unlink(control_maddr.value_for_protocol(protocols.P_UNIX)) + except FileNotFoundError: + pass + try: + os.unlink(listen_maddr.value_for_protocol(protocols.P_UNIX)) + except FileNotFoundError: + pass + async with _make_p2pd_pair( + control_maddr=control_maddr, + listen_maddr=listen_maddr, + enable_control=enable_control, + enable_connmgr=enable_connmgr, + enable_dht=enable_dht, + enable_pubsub=enable_pubsub, + ) as pair: + yield pair + + +@asynccontextmanager +async def make_p2pd_pair_ip4(enable_control, enable_connmgr, enable_dht, enable_pubsub): + control_maddr = Multiaddr(f"/ip4/127.0.0.1/tcp/{find_open_port()}") + listen_maddr = Multiaddr(f"/ip4/127.0.0.1/tcp/{find_open_port()}") + async with _make_p2pd_pair( + control_maddr=control_maddr, + listen_maddr=listen_maddr, + enable_control=enable_control, + enable_connmgr=enable_connmgr, + enable_dht=enable_dht, + enable_pubsub=enable_pubsub, + ) as pair: + yield pair + + +@asynccontextmanager +async def _make_p2pd_pair( + control_maddr, + listen_maddr, + enable_control, + enable_connmgr, + enable_dht, + enable_pubsub, +): + p2pd = Daemon( + control_maddr=control_maddr, + enable_control=enable_control, + enable_connmgr=enable_connmgr, + enable_dht=enable_dht, + enable_pubsub=enable_pubsub, + ) + # wait for daemon ready + await p2pd.wait_until_ready() + client = Client(control_maddr=control_maddr, listen_maddr=listen_maddr) + try: + async with client.listen(): + yield DaemonTuple(daemon=p2pd, client=client) + finally: + if not p2pd.is_closed: + p2pd.close() + + +async def _check_connection(p2pd_tuple_0, p2pd_tuple_1): + peer_id_0, _ = await p2pd_tuple_0.identify() + peer_id_1, _ = await p2pd_tuple_1.identify() + peers_0 = [pinfo.peer_id for pinfo in await p2pd_tuple_0.list_peers()] + peers_1 = [pinfo.peer_id for pinfo in await p2pd_tuple_1.list_peers()] + return (peer_id_0 in peers_1) and (peer_id_1 in peers_0) + + +async def connect_safe(p2pd_tuple_0, p2pd_tuple_1): + peer_id_1, maddrs_1 = 
await p2pd_tuple_1.identify() + await p2pd_tuple_0.connect(peer_id_1, maddrs_1) + await try_until_success( + functools.partial( + _check_connection, p2pd_tuple_0=p2pd_tuple_0, p2pd_tuple_1=p2pd_tuple_1 + ) + )
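# Usage sketch (illustrative): how the new hivemind.p2p.P2P wrapper added in this
# patch is exercised end to end, distilled from tests/test_p2p_daemon.py. The
# 'square' handler name and the inline lambda are examples only; the private field
# _host_port is read the same way the tests read it.
import asyncio

from hivemind.p2p import P2P
from hivemind.utils import MSGPackSerializer


async def main():
    # Spawn a p2pd subprocess and register a raw byte-stream handler on it.
    server = await P2P.create()
    await server.add_stream_handler(
        'square', lambda x: MSGPackSerializer.dumps(MSGPackSerializer.loads(x) ** 2))

    # A second daemon bootstraps from the first one's host port and peer id.
    client = await P2P.create(
        bootstrap=True,
        bootstrap_peers=[f'/ip4/127.0.0.1/tcp/{server._host_port}/p2p/{server.id}'])
    await client.wait_for_at_least_n_peers(1)

    # Handlers take raw bytes and return raw bytes; serialization is up to the caller.
    response = await client.call_peer_handler(server.id, 'square', MSGPackSerializer.dumps(3))
    assert MSGPackSerializer.loads(response) == 9

    await server.stop_listening()
    await server.shutdown()
    await client.shutdown()


if __name__ == '__main__':
    asyncio.run(main())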
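# Wire-format sketch (illustrative): P2P.send_raw_data/receive_raw_data frame every
# message with an 8-byte big-endian length prefix, and unary handlers exchange a
# one-byte status frame (b'\x00' = result, b'\x01' = RPCError) before the serialized
# protobuf. The in-memory StreamReader below stands in for a real libp2p stream.
import asyncio

from hivemind.p2p import P2P
from hivemind.proto import dht_pb2  # any message type works; dht_pb2 is what the tests use


async def demo_framing():
    reader = asyncio.StreamReader()

    # Emulate what send_protobuf() writes: the status frame, then the payload frame.
    payload = dht_pb2.PingRequest(validate=True).SerializeToString()
    reader.feed_data(len(P2P.RESULT_MESSAGE).to_bytes(P2P.HEADER_LEN, P2P.BYTEORDER) + P2P.RESULT_MESSAGE)
    reader.feed_data(len(payload).to_bytes(P2P.HEADER_LEN, P2P.BYTEORDER) + payload)

    # receive_protobuf() returns (result, error); error is None when the status frame is b'\x00'.
    result, error = await P2P.receive_protobuf(dht_pb2.PingRequest, reader)
    assert error is None and result.validate


if __name__ == '__main__':
    asyncio.run(demo_framing())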