Fix RPC in ServicerBase derivatives for test_training_averager
borzunov committed Jul 16, 2021
1 parent 12e8039 commit 20f19b1
Showing 5 changed files with 24 additions and 14 deletions.
13 changes: 10 additions & 3 deletions hivemind/averaging/allreduce.py
@@ -30,6 +30,10 @@ class AllReduceRunner(ServicerBase):
creating a full DecentralizedAverager.
:note: this class returns **differences** between averaged and local tensors in order to improve numerical stability
:param p2p: a hivemind.p2p.P2P instance used for communication with other peers
:param servicer: a hivemind.p2p.ServicerBase instance whose RPC signatures are used when requesting other peers.
Typically, it is a DecentralizedAverager instance or its derivative.
If None, uses ``self`` for this purpose (since this class may be a servicer itself for testing purposes).
:param group_id: unique identifier of this specific all-reduce run
:param tensors: local tensors that should be averaged with groupmates
:param tensors: local tensors that should be averaged with groupmates
@@ -47,6 +51,7 @@ def __init__(
self,
*,
p2p: P2P,
servicer: Optional[ServicerBase],
group_id: GroupID,
tensors: Sequence[torch.Tensor],
ordered_group_endpoints: Sequence[Endpoint],
@@ -60,6 +65,10 @@ def __init__(
self.endpoint = p2p.id
assert self.endpoint in ordered_group_endpoints, "endpoint is not a part of the group"

if servicer is None:
servicer = self
self._servicer = servicer

modes = modes or tuple(AveragingMode.CLIENT if frac == 0 else AveragingMode.NODE for frac in peer_fractions)
weights = weights or tuple(int(mode != AveragingMode.AUX) for mode in modes)
assert len(weights) == len(modes) == len(ordered_group_endpoints), "lists have inconsistent length"
@@ -102,9 +111,7 @@ def group_size(self):
return len(self.ordered_group_endpoints)

def _get_stub(self, peer: Endpoint) -> StubBase:
from hivemind.averaging.averager import DecentralizedAverager

return DecentralizedAverager.get_stub(self._p2p, peer)
return self._servicer.get_stub(self._p2p, peer)

async def run(self) -> AsyncIterator[torch.Tensor]:
"""Run all-reduce, return differences between averaged and original tensors as they are computed"""
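A minimal, self-contained sketch (not hivemind's actual code) of the pattern this change introduces: the runner keeps the servicer it was given, falls back to itself when none is provided, and builds stubs through that object instead of through a hard-coded DecentralizedAverager import. All names below are illustrative.

```python
from typing import Optional


class ToyServicer:
    """Stand-in for hivemind.p2p.ServicerBase: stubs are tied to the class that registered the RPC handlers."""

    @classmethod
    def get_stub(cls, p2p, peer) -> str:
        return f"{cls.__name__} stub -> {peer}"


class ToyAllReduceRunner(ToyServicer):
    def __init__(self, p2p, servicer: Optional[ToyServicer] = None):
        self._p2p = p2p
        # Mirrors the new `if servicer is None: servicer = self` fallback above.
        self._servicer = servicer if servicer is not None else self

    def _get_stub(self, peer) -> str:
        # Dispatch through whichever servicer was injected (an instance or a class)
        # rather than importing and hard-coding a specific averager class.
        return self._servicer.get_stub(self._p2p, peer)


print(ToyAllReduceRunner(p2p=None)._get_stub("peer-A"))  # ToyAllReduceRunner stub -> peer-A
```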
8 changes: 7 additions & 1 deletion hivemind/averaging/averager.py
@@ -214,7 +214,12 @@ async def _run():
logger.debug(f"The averager is running in client mode.")

self._matchmaking = Matchmaking(
self._p2p, self.schema_hash, self.dht, **self.matchmaking_kwargs, client_mode=self.client_mode
self._p2p,
self,
self.schema_hash,
self.dht,
client_mode=self.client_mode,
**self.matchmaking_kwargs,
)
if not self.client_mode:
asyncio.create_task(self._declare_for_download_periodically())
@@ -378,6 +383,7 @@ async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kw
async with self.get_tensors_async() as local_tensors:
allreduce = AllReduceRunner(
p2p=self._p2p,
servicer=self,
group_id=group_info.group_id,
tensors=local_tensors,
ordered_group_endpoints=group_info.endpoints,
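Hedged illustration of why the averager now passes ``self`` to both Matchmaking and AllReduceRunner (assuming hivemind's ServicerBase derives RPC handler names from the concrete class): a DecentralizedAverager subclass, such as the averager exercised by test_training_averager, registers handlers under its own class name, so stubs created from the base class would call a namespace the peer never registered. A toy sketch of that namespacing:

```python
class ToyServicerBase:
    @classmethod
    def get_stub(cls, p2p, peer) -> str:
        # Handler names are assumed to be derived from the concrete class name.
        return f"calling {cls.__name__}.rpc_* on {peer}"


class ToyDecentralizedAverager(ToyServicerBase):
    pass


class ToyTrainingAverager(ToyDecentralizedAverager):
    pass


averager = ToyTrainingAverager()
print(ToyDecentralizedAverager.get_stub(None, "peer"))  # base-class namespace: wrong for a subclass
print(averager.get_stub(None, "peer"))                  # subclass namespace: matches its registered handlers
```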
8 changes: 4 additions & 4 deletions hivemind/averaging/matchmaking.py
@@ -13,7 +13,7 @@
from hivemind.averaging.group_info import GroupInfo
from hivemind.averaging.key_manager import GroupKeyManager, GroupKey
from hivemind.dht import DHT, DHTID, DHTExpiration
from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID as Endpoint
from hivemind.p2p import P2P, P2PContext, P2PHandlerError, PeerID as Endpoint, ServicerBase
from hivemind.utils import get_logger, timed_storage, TimedStorage, get_dht_time
from hivemind.utils.asyncio import anext
from hivemind.proto import averaging_pb2
@@ -37,6 +37,7 @@ class Matchmaking:
def __init__(
self,
p2p: P2P,
servicer: ServicerBase,
schema_hash: bytes,
dht: DHT,
*,
@@ -57,6 +58,7 @@ def __init__(

super().__init__()
self._p2p = p2p
self._servicer = servicer
self.endpoint = p2p.id
self.schema_hash = schema_hash
self.group_key_manager = GroupKeyManager(dht, prefix, initial_group_bits, target_group_size)
@@ -173,9 +175,7 @@ async def request_join_group(self, leader: Endpoint, expiration_time: DHTExpirat
stream: AsyncIterator[averaging_pb2.MessageFromLeader] = None
try:
async with self.lock_request_join_group:
from hivemind.averaging.averager import DecentralizedAverager

leader_stub = DecentralizedAverager.get_stub(self._p2p, leader)
leader_stub = self._servicer.get_stub(self._p2p, leader)

stream = leader_stub.rpc_join_group(
averaging_pb2.JoinRequest(
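A brief sketch of the dependency-injection shape Matchmaking takes after this change: the servicer is handed in by the caller, which removes both the hard-coded DecentralizedAverager reference and the function-local import that previously worked around the circular dependency between matchmaking.py and averager.py. Names are illustrative, not hivemind's code.

```python
class ToyMatchmaking:
    def __init__(self, p2p, servicer):
        self._p2p = p2p
        self._servicer = servicer  # injected by the averager; no averager import needed here

    async def request_join_group(self, leader):
        # The leader stub comes from the injected servicer, so an averager subclass
        # reaches its own rpc_join_group handler on the leader peer.
        leader_stub = self._servicer.get_stub(self._p2p, leader)
        return leader_stub
```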
2 changes: 1 addition & 1 deletion hivemind/dht/__init__.py
@@ -24,9 +24,9 @@
from multiaddr import Multiaddr

from hivemind.dht.node import DHTNode
from hivemind.p2p import P2P, PeerID
from hivemind.dht.routing import DHTID, DHTKey, DHTValue, Subkey
from hivemind.dht.validation import CompositeValidator, RecordValidatorBase
from hivemind.p2p import P2P, PeerID
from hivemind.utils import DHTExpiration, MPFuture, ValueWithExpiration, await_cancelled, get_logger, switch_to_uvloop

logger = get_logger(__name__)
7 changes: 2 additions & 5 deletions tests/test_allreduce.py
@@ -176,10 +176,6 @@ async def send_tensors(sender_index: int):
async def test_allreduce_protocol(peer_modes, averaging_weights, peer_fractions, part_size_bytes):
"""Run group allreduce protocol manually without grpc, see if the internal logic is working as intended"""

class AllreduceRunnerForTesting(AllReduceRunner):
def _get_stub(self, peer: str) -> StubBase:
return AllreduceRunnerForTesting.get_stub(self._p2p, peer)

p2ps = [await P2P.create()]
visible_maddrs = await p2ps[0].get_visible_maddrs()
p2ps += await asyncio.gather(*[P2P.create(initial_peers=visible_maddrs) for _ in range(3)])
@@ -194,8 +190,9 @@ def _get_stub(self, peer: str) -> StubBase:

allreduce_protocols = []
for p2p in p2ps:
allreduce_protocol = AllreduceRunnerForTesting(
allreduce_protocol = AllReduceRunner(
p2p=p2p,
servicer=AllReduceRunner,
group_id=group_id,
tensors=[x.clone() for x in tensors_by_peer[p2p.id]],
ordered_group_endpoints=peers,
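Usage note (hedged): because stub creation now goes through whatever object is passed as ``servicer``, and ServicerBase's ``get_stub`` can be called on the class itself, the test no longer needs the throwaway AllreduceRunnerForTesting subclass; passing the AllReduceRunner class directly gives the same behavior. Constructor arguments not shown in this hunk are assumed to stay as they were.

```python
# Before this commit, a subclass existed only to redirect stub creation back to itself;
# now the class is injected as the servicer, so its own RPC handler names are used.
allreduce_protocol = AllReduceRunner(
    p2p=p2p,
    servicer=AllReduceRunner,
    group_id=group_id,
    tensors=[x.clone() for x in tensors_by_peer[p2p.id]],
    ordered_group_endpoints=peers,
    # ... remaining keyword arguments unchanged from the existing test
)
```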
