From 20f19b122fdfc887317961b7c396103f82b9fbde Mon Sep 17 00:00:00 2001
From: Aleksandr Borzunov
Date: Fri, 16 Jul 2021 21:00:29 +0300
Subject: [PATCH] Fix RPC in ServicerBase derivatives for test_training_averager

---
 hivemind/averaging/allreduce.py   | 13 ++++++++++---
 hivemind/averaging/averager.py    |  8 +++++++-
 hivemind/averaging/matchmaking.py |  8 ++++----
 hivemind/dht/__init__.py          |  2 +-
 tests/test_allreduce.py           |  7 ++-----
 5 files changed, 24 insertions(+), 14 deletions(-)

diff --git a/hivemind/averaging/allreduce.py b/hivemind/averaging/allreduce.py
index 16bf6d605..82c08e511 100644
--- a/hivemind/averaging/allreduce.py
+++ b/hivemind/averaging/allreduce.py
@@ -30,6 +30,10 @@ class AllReduceRunner(ServicerBase):
       creating a full DecentralizedAverager.

     :note: this class returns **differences** between averaged and local tensors in order to improve numerical stability
+    :param p2p: a hivemind.p2p.P2P instance used for communication with other peers
+    :param servicer: a hivemind.p2p.ServicerBase instance whose RPC signatures are used when requesting other peers.
+      Typically, it is a DecentralizedAverager instance or its derivative.
+      If None, uses ``self`` for this purpose (since this class may be a servicer itself for testing purposes).
     :param group_id: unique identifier of this specific all-reduce run
     :param tensors: local tensors that should be averaged with groupmates
     :param tensors: local tensors that should be averaged with groupmates
@@ -47,6 +51,7 @@ def __init__(
         self,
         *,
         p2p: P2P,
+        servicer: Optional[ServicerBase],
         group_id: GroupID,
         tensors: Sequence[torch.Tensor],
         ordered_group_endpoints: Sequence[Endpoint],
@@ -60,6 +65,10 @@ def __init__(
         self.endpoint = p2p.id
         assert self.endpoint in ordered_group_endpoints, "endpoint is not a part of the group"

+        if servicer is None:
+            servicer = self
+        self._servicer = servicer
+
         modes = modes or tuple(AveragingMode.CLIENT if frac == 0 else AveragingMode.NODE for frac in peer_fractions)
         weights = weights or tuple(int(mode != AveragingMode.AUX) for mode in modes)
         assert len(weights) == len(modes) == len(ordered_group_endpoints), "lists have inconsistent length"
@@ -102,9 +111,7 @@ def group_size(self):
         return len(self.ordered_group_endpoints)

     def _get_stub(self, peer: Endpoint) -> StubBase:
-        from hivemind.averaging.averager import DecentralizedAverager
-
-        return DecentralizedAverager.get_stub(self._p2p, peer)
+        return self._servicer.get_stub(self._p2p, peer)

     async def run(self) -> AsyncIterator[torch.Tensor]:
         """Run all-reduce, return differences between averaged and original tensors as they are computed"""
diff --git a/hivemind/averaging/averager.py b/hivemind/averaging/averager.py
index 0056386ee..7c2253efa 100644
--- a/hivemind/averaging/averager.py
+++ b/hivemind/averaging/averager.py
@@ -214,7 +214,12 @@ async def _run():
                     logger.debug(f"The averager is running in client mode.")

                 self._matchmaking = Matchmaking(
-                    self._p2p, self.schema_hash, self.dht, **self.matchmaking_kwargs, client_mode=self.client_mode
+                    self._p2p,
+                    self,
+                    self.schema_hash,
+                    self.dht,
+                    client_mode=self.client_mode,
+                    **self.matchmaking_kwargs,
                 )
                 if not self.client_mode:
                     asyncio.create_task(self._declare_for_download_periodically())
@@ -378,6 +383,7 @@ async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kw
             async with self.get_tensors_async() as local_tensors:
                 allreduce = AllReduceRunner(
                     p2p=self._p2p,
+                    servicer=self,
                     group_id=group_info.group_id,
                     tensors=local_tensors,
                     ordered_group_endpoints=group_info.endpoints,
diff --git a/hivemind/averaging/matchmaking.py b/hivemind/averaging/matchmaking.py
index 30e1bc3d1..3e9bb52a1 100644
--- a/hivemind/averaging/matchmaking.py
+++ b/hivemind/averaging/matchmaking.py
@@ -13,7 +13,7 @@
 from hivemind.averaging.group_info import GroupInfo
 from hivemind.averaging.key_manager import GroupKeyManager, GroupKey
 from hivemind.dht import DHT, DHTID, DHTExpiration
-from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID as Endpoint
+from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID as Endpoint, ServicerBase
 from hivemind.utils import get_logger, timed_storage, TimedStorage, get_dht_time
 from hivemind.utils.asyncio import anext
 from hivemind.proto import averaging_pb2
@@ -37,6 +37,7 @@ class Matchmaking:
     def __init__(
         self,
         p2p: P2P,
+        servicer: ServicerBase,
         schema_hash: bytes,
         dht: DHT,
         *,
@@ -57,6 +58,7 @@ def __init__(
         super().__init__()

         self._p2p = p2p
+        self._servicer = servicer
         self.endpoint = p2p.id
         self.schema_hash = schema_hash
         self.group_key_manager = GroupKeyManager(dht, prefix, initial_group_bits, target_group_size)
@@ -173,9 +175,7 @@ async def request_join_group(self, leader: Endpoint, expiration_time: DHTExpirat
         stream: AsyncIterator[averaging_pb2.MessageFromLeader] = None
         try:
             async with self.lock_request_join_group:
-                from hivemind.averaging.averager import DecentralizedAverager
-
-                leader_stub = DecentralizedAverager.get_stub(self._p2p, leader)
+                leader_stub = self._servicer.get_stub(self._p2p, leader)

                 stream = leader_stub.rpc_join_group(
                     averaging_pb2.JoinRequest(
diff --git a/hivemind/dht/__init__.py b/hivemind/dht/__init__.py
index 79c6bcc50..aa2edf3f1 100644
--- a/hivemind/dht/__init__.py
+++ b/hivemind/dht/__init__.py
@@ -24,9 +24,9 @@
 from multiaddr import Multiaddr

 from hivemind.dht.node import DHTNode
-from hivemind.p2p import P2P, PeerID
 from hivemind.dht.routing import DHTID, DHTKey, DHTValue, Subkey
 from hivemind.dht.validation import CompositeValidator, RecordValidatorBase
+from hivemind.p2p import P2P, PeerID
 from hivemind.utils import DHTExpiration, MPFuture, ValueWithExpiration, await_cancelled, get_logger, switch_to_uvloop

 logger = get_logger(__name__)
diff --git a/tests/test_allreduce.py b/tests/test_allreduce.py
index 8c3c06810..f3254882c 100644
--- a/tests/test_allreduce.py
+++ b/tests/test_allreduce.py
@@ -176,10 +176,6 @@ async def send_tensors(sender_index: int):
 async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions, part_size_bytes):
     """Run group allreduce protocol manually without grpc, see if the internal logic is working as intended"""

-    class AllreduceRunnerForTesting(AllReduceRunner):
-        def _get_stub(self, peer: str) -> StubBase:
-            return AllreduceRunnerForTesting.get_stub(self._p2p, peer)
-
     p2ps = [await P2P.create()]
     visible_maddrs = await p2ps[0].get_visible_maddrs()
     p2ps += await asyncio.gather(*[P2P.create(initial_peers=visible_maddrs) for _ in range(3)])
@@ -194,8 +190,9 @@ def _get_stub(self, peer: str) -> StubBase:

     allreduce_protocols = []
     for p2p in p2ps:
-        allreduce_protocol = AllreduceRunnerForTesting(
+        allreduce_protocol = AllReduceRunner(
             p2p=p2p,
+            servicer=AllReduceRunner,
             group_id=group_id,
             tensors=[x.clone() for x in tensors_by_peer[p2p.id]],
             ordered_group_endpoints=peers,
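
Note (illustration only, not part of the patch): the sketch below shows the servicer indirection this change introduces, i.e. AllReduceRunner and Matchmaking now resolve RPC stubs through whatever servicer they were given instead of importing DecentralizedAverager. The classes are simplified stand-ins written for this note; only the "servicer is None, fall back to self" behaviour and the _get_stub delegation mirror the diff, while hivemind's real P2P/stub machinery is omitted.

    from typing import Optional


    class StubBase:
        """Stand-in for hivemind.p2p.StubBase: just records which namespace and peer it targets."""

        def __init__(self, namespace: str, peer: str):
            self.namespace, self.peer = namespace, peer


    class ServicerBase:
        """Stand-in for hivemind.p2p.ServicerBase: the real get_stub builds an RPC stub over libp2p."""

        @classmethod
        def get_stub(cls, p2p, peer: str) -> StubBase:
            # the stub's namespace comes from the servicer class, which is the point of the patch:
            # whoever is passed as `servicer` decides which RPC handlers the stub will call
            return StubBase(namespace=cls.__name__, peer=peer)


    class AllReduceRunner(ServicerBase):
        """Simplified stand-in that mirrors only the `servicer` handling from the diff above."""

        def __init__(self, *, p2p, servicer: Optional[ServicerBase] = None):
            self._p2p = p2p
            # same fallback as in the patch: with no servicer given, the runner serves its own RPCs
            self._servicer = servicer if servicer is not None else self

        def _get_stub(self, peer: str) -> StubBase:
            # previously hard-wired to DecentralizedAverager.get_stub(...)
            return self._servicer.get_stub(self._p2p, peer)


    # A DecentralizedAverager passes servicer=self; the updated test passes the class itself:
    runner = AllReduceRunner(p2p=None, servicer=AllReduceRunner)
    assert runner._get_stub("peer-1").namespace == "AllReduceRunner"

This is what lets tests/test_allreduce.py pass servicer=AllReduceRunner and exercise the protocol without constructing a DecentralizedAverager, replacing the removed AllreduceRunnerForTesting subclass.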