Commit a8fcb0a

Convert AllReduceRunner, Matchmaking, and GroupKeyManager to libp2p backend
borzunov committed Jul 15, 2021
1 parent c89b598 commit a8fcb0a
Showing 9 changed files with 126 additions and 239 deletions.
88 changes: 41 additions & 47 deletions hivemind/averaging/allreduce.py
@@ -2,14 +2,14 @@
 from typing import Sequence, Dict, Tuple, AsyncIterator, Any, Optional
 from enum import Enum

-import grpc
 import torch

 from hivemind.averaging.partition import TensorPartContainer, TensorPartReducer, AllreduceException
-from hivemind.utils import Endpoint, get_logger, ChannelCache
+from hivemind.p2p import P2P, P2PContext, PeerID as Endpoint, ServicerBase, StubBase
+from hivemind.utils import get_logger
 from hivemind.utils.asyncio import anext, achain, aiter, aenumerate, amap_in_executor
 from hivemind.utils.compression import serialize_torch_tensor, deserialize_torch_tensor
-from hivemind.proto import averaging_pb2_grpc, averaging_pb2
+from hivemind.proto import averaging_pb2

 # flavour types
 GroupID = bytes
@@ -22,7 +22,7 @@ class AveragingMode(Enum):
     AUX = 2


-class AllReduceRunner(averaging_pb2_grpc.DecentralizedAveragingServicer):
+class AllReduceRunner(ServicerBase):
     """
     An internal class that runs butterfly AllReduce in a predefined group of averagers
@@ -43,17 +43,20 @@ class AllReduceRunner(averaging_pb2_grpc.DecentralizedAveragingServicer):
     def __init__(
         self,
         *,
+        p2p: P2P,
         group_id: GroupID,
         tensors: Sequence[torch.Tensor],
-        endpoint: Endpoint,
         ordered_group_endpoints: Sequence[Endpoint],
         peer_fractions: Tuple[float, ...],
         weights: Optional[Sequence[float]] = None,
         modes: Optional[Sequence[AveragingMode]] = None,
         gathered: Optional[Dict[Endpoint, Any]] = None,
         **kwargs,
     ):
-        assert endpoint in ordered_group_endpoints, "endpoint is not a part of the group"
+        self._p2p = p2p
+        self.endpoint = p2p.id
+        assert self.endpoint in ordered_group_endpoints, "endpoint is not a part of the group"
+
         modes = modes or tuple(AveragingMode.CLIENT if frac == 0 else AveragingMode.NODE for frac in peer_fractions)
         weights = weights or tuple(int(mode != AveragingMode.AUX) for mode in modes)
         assert len(weights) == len(modes) == len(ordered_group_endpoints), "lists have inconsistent length"
@@ -62,7 +65,7 @@ def __init__(
             assert mode != AveragingMode.CLIENT or frac == 0, "client-mode peer should have zero all-reduce fraction"
             assert mode != AveragingMode.AUX or weight == 0, "auxiliary peer should have zero averaging weight"

-        self.group_id, self.endpoint, self.ordered_group_endpoints = group_id, endpoint, ordered_group_endpoints
+        self.group_id, self.ordered_group_endpoints = group_id, ordered_group_endpoints
         self.modes, self.peer_fractions, self.gathered = modes, peer_fractions, gathered

         self._future = asyncio.Future()
@@ -95,8 +98,10 @@ def __contains__(self, endpoint: Endpoint):
     def group_size(self):
         return len(self.ordered_group_endpoints)

-    def _get_peer_stub(self, peer: Endpoint) -> averaging_pb2_grpc.DecentralizedAveragingStub:
-        return ChannelCache.get_stub(peer, averaging_pb2_grpc.DecentralizedAveragingStub, aio=True)
+    def _get_stub(self, peer: Endpoint) -> StubBase:
+        from hivemind.averaging.averager import DecentralizedAverager
+
+        return DecentralizedAverager.get_stub(self._p2p, peer)

     async def run(self) -> AsyncIterator[torch.Tensor]:
         """Run all-reduce, return differences between averaged and original tensors as they are computed"""
@@ -136,46 +141,35 @@ async def _communicate_with_peer(self, peer_endpoint: Endpoint):

         else:
             loop = asyncio.get_event_loop()
-            stream = self._get_peer_stub(peer_endpoint).rpc_aggregate_part()
-            write_task = asyncio.create_task(self._write_to_peer(stream, peer_index))
-
-            try:
-                code = None
-                async for part_index, msg in aenumerate(stream):
-                    if code is None:
-                        code = msg.code
-                    averaged_part_delta = await loop.run_in_executor(None, deserialize_torch_tensor, msg.tensor_part)
-                    self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part_delta)
-                await write_task
-
-                if code != averaging_pb2.AVERAGED_PART:
-                    raise AllreduceException(
-                        f"peer {peer_endpoint} returned {averaging_pb2.MessageCode.Name(code)} "
-                        f"instead of {averaging_pb2.MessageCode.Name(averaging_pb2.AVERAGED_PART)}"
-                        f", allreduce failed"
-                    )
-            finally:
-                if not write_task.done():
-                    write_task.cancel()
-
-    async def _write_to_peer(self, stream: grpc.aio.StreamStreamCall, peer_index: int):
+            code = None
+            stream = self._get_stub(peer_endpoint).rpc_aggregate_part(self._generate_input_for_peer(peer_index))
+            async for part_index, msg in aenumerate(stream):
+                if code is None:
+                    code = msg.code
+                averaged_part_delta = await loop.run_in_executor(None, deserialize_torch_tensor, msg.tensor_part)
+                self.tensor_part_container.register_processed_part(peer_index, part_index, averaged_part_delta)
+
+            if code != averaging_pb2.AVERAGED_PART:
+                raise AllreduceException(
+                    f"peer {peer_endpoint} returned {averaging_pb2.MessageCode.Name(code)} "
+                    f"instead of {averaging_pb2.MessageCode.Name(averaging_pb2.AVERAGED_PART)}"
+                    f", allreduce failed"
+                )
+
+    async def _generate_input_for_peer(self, peer_index: int) -> AsyncIterator[averaging_pb2.AveragingData]:
         parts_aiter = self.tensor_part_container.iterate_input_parts_for(peer_index)
         first_part = await anext(parts_aiter)
-        await stream.write(
-            averaging_pb2.AveragingData(
-                code=averaging_pb2.PART_FOR_AVERAGING,
-                group_id=self.group_id,
-                endpoint=self.endpoint,
-                tensor_part=first_part,
-            )
-        )
+        yield averaging_pb2.AveragingData(
+            code=averaging_pb2.PART_FOR_AVERAGING,
+            group_id=self.group_id,
+            endpoint=self.endpoint.to_base58(),
+            tensor_part=first_part,
+        )
         async for part in parts_aiter:
-            await stream.write(averaging_pb2.AveragingData(tensor_part=part))
-
-        await stream.done_writing()
+            yield averaging_pb2.AveragingData(tensor_part=part)

     async def rpc_aggregate_part(
-        self, stream: AsyncIterator[averaging_pb2.AveragingData], context: grpc.ServicerContext
+        self, stream: AsyncIterator[averaging_pb2.AveragingData], _: P2PContext
     ) -> AsyncIterator[averaging_pb2.AveragingData]:
         """a peer sends us a part of his tensor; we should average it with other peers and return the difference"""
         request: averaging_pb2.AveragingData = await anext(stream)
@@ -186,7 +180,7 @@ async def rpc_aggregate_part(

         elif request.code == averaging_pb2.PART_FOR_AVERAGING:
             try:
-                sender_index = self.sender_endpoints.index(request.endpoint)
+                sender_index = self.sender_endpoints.index(Endpoint.from_base58(request.endpoint))
                 async for msg in self._accumulate_parts_streaming(achain(aiter(request), stream), sender_index):
                     yield msg

@@ -224,9 +218,9 @@ async def _accumulate_parts_streaming(self, stream: AsyncIterator[averaging_pb2.
         yield averaging_pb2.AveragingData(code=averaging_pb2.AVERAGED_PART, tensor_part=serialized_delta)

     async def _send_error_to_peer(self, peer_endpoint: Endpoint, code: averaging_pb2.MessageCode):
-        stream = self._get_peer_stub(peer_endpoint).rpc_aggregate_part()
-        await stream.write(averaging_pb2.AveragingData(group_id=self.group_id, endpoint=self.endpoint, code=code))
-        await stream.done_writing()
+        error = averaging_pb2.AveragingData(group_id=self.group_id, endpoint=self.endpoint.to_base58(), code=code)
+        async for _ in self._get_stub(peer_endpoint).rpc_aggregate_part(aiter(error)):
+            pass

     def finalize(self, *, cancel: bool = False, exception: Optional[BaseException] = None):
         """finish or terminate AllReduceRunner, propagate any errors / cancellations to peers."""
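Note on the core API shift in this file: the old gRPC client opened a bidirectional stream and pushed messages imperatively (stream.write(...), then stream.done_writing()) from a separate writer task, while the libp2p-style stub takes an async iterator of requests and returns an async iterator of responses, so _generate_input_for_peer becomes a plain async generator. A minimal self-contained sketch of this calling convention; Msg, toy_parts, and this rpc_aggregate_part stand-in are hypothetical names for illustration, not hivemind's actual classes:

import asyncio
from typing import AsyncIterator

class Msg:
    """Hypothetical stand-in for averaging_pb2.AveragingData."""

    def __init__(self, payload: str):
        self.payload = payload

async def toy_parts() -> AsyncIterator[Msg]:
    # The caller's side of the stream is just an async generator,
    # mirroring _generate_input_for_peer above.
    for chunk in ("part-0", "part-1", "part-2"):
        yield Msg(chunk)

async def rpc_aggregate_part(requests: AsyncIterator[Msg]) -> AsyncIterator[Msg]:
    # Servicer stand-in: consume the request stream lazily and
    # yield one response per incoming part.
    async for req in requests:
        yield Msg(f"avg({req.payload})")

async def main() -> None:
    # New convention: pass the request iterator in, iterate responses out.
    # No stream.write() / done_writing() bookkeeping and no writer task to cancel.
    async for resp in rpc_aggregate_part(toy_parts()):
        print(resp.payload)  # avg(part-0), avg(part-1), avg(part-2)

asyncio.run(main())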
12 changes: 6 additions & 6 deletions hivemind/averaging/averager.py
@@ -188,7 +188,7 @@ def allow_state_sharing(self, value: bool):

    @property
    def endpoint(self) -> Endpoint:
-        return self.p2p.id
+        return self._p2p.id

    def run(self):
        """
@@ -207,14 +207,14 @@ def _run_internal(self):
        with ThreadPoolExecutor(max_workers=1) as pipe_awaiter:

            async def _run():
-                self.p2p = await self.dht.replicate_p2p()
+                self._p2p = await self.dht.replicate_p2p()
                if not self.client_mode:
-                    await self.add_p2p_handlers(self.p2p)
+                    await self.add_p2p_handlers(self._p2p)
                else:
                    logger.debug(f"The averager is running in client mode.")

                self._matchmaking = Matchmaking(
-                    self.endpoint, self.schema_hash, self.dht, **self.matchmaking_kwargs, client_mode=self.client_mode
+                    self._p2p, self.schema_hash, self.dht, **self.matchmaking_kwargs, client_mode=self.client_mode
                )
                if not self.client_mode:
                    asyncio.create_task(self._declare_for_download_periodically())
@@ -379,9 +379,9 @@ async def _run_allreduce(self, group_info: GroupInfo, min_vector_size: int, **kw

        async with self.get_tensors_async() as local_tensors:
            allreduce = AllReduceRunner(
+                p2p=self._p2p,
                group_id=group_info.group_id,
                tensors=local_tensors,
-                endpoint=self.endpoint,
                ordered_group_endpoints=group_info.endpoints,
                peer_fractions=peer_fractions,
                weights=weights,
@@ -551,7 +551,7 @@ async def _load_state_from_peers(self, future: MPFuture):
                if peer != self.endpoint:
                    logger.info(f"Downloading parameters from peer {peer}")
                    try:
-                        stub = self.get_stub(self.p2p, peer)
+                        stub = self.get_stub(self._p2p, peer)
                        stream = stub.rpc_download_state(averaging_pb2.DownloadRequest())
                        current_tensor_parts, tensors = [], []
                        async for message in stream:
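Note on the constructor change above: instead of threading a separate endpoint value alongside the transport, AllReduceRunner now receives the P2P handle and derives its own identity from it (self.endpoint = p2p.id), so the two can never disagree. A toy sketch of that pattern, under the assumption that PeerID behaves like a plain comparable value; ToyP2P and ToyRunner are hypothetical stand-ins, not hivemind APIs:

from dataclasses import dataclass
from typing import Tuple

@dataclass(frozen=True)
class ToyP2P:
    """Hypothetical stand-in for hivemind.p2p.P2P; `id` plays the role of PeerID."""

    id: str

class ToyRunner:
    def __init__(self, *, p2p: ToyP2P, ordered_group_endpoints: Tuple[str, ...]):
        self._p2p = p2p
        # Identity is derived from the transport handle rather than passed
        # separately, so it always matches the connection actually used.
        self.endpoint = p2p.id
        assert self.endpoint in ordered_group_endpoints, "endpoint is not a part of the group"
        self.ordered_group_endpoints = ordered_group_endpoints

runner = ToyRunner(p2p=ToyP2P(id="peer-A"), ordered_group_endpoints=("peer-A", "peer-B"))
print(runner.endpoint)  # peer-A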
11 changes: 6 additions & 5 deletions hivemind/averaging/key_manager.py
@@ -5,9 +5,10 @@

 import numpy as np

-from hivemind.dht import DHT
 from hivemind.averaging.group_info import GroupInfo
-from hivemind.utils import get_logger, Endpoint, DHTExpiration, get_dht_time, ValueWithExpiration
+from hivemind.dht import DHT
+from hivemind.p2p import PeerID as Endpoint
+from hivemind.utils import get_logger, DHTExpiration, get_dht_time, ValueWithExpiration

 GroupKey = str
 GROUP_PATTERN = re.compile("^(([^.])+)[.]0b[01]*$")  # e.g. bert_exp4_averaging.0b01001101
@@ -72,7 +73,7 @@ async def declare_averager(
        expiration_time = expiration_time if looking_for_group else float(np.nextafter(expiration_time, float("inf")))
        return await self.dht.store(
            key=group_key,
-            subkey=endpoint,
+            subkey=endpoint.to_base58(),
            value=looking_for_group,
            expiration_time=expiration_time,
            return_future=True,
@@ -93,11 +94,11 @@ async def get_averagers(self, group_key: GroupKey, only_active: bool) -> List[Tu
            logger.debug(f"Allreduce group not found: {group_key}, creating new group.")
            return []
        averagers = [
-            (key, entry.expiration_time)
+            (Endpoint.from_base58(key), entry.expiration_time)
            for key, entry in result.value.items()
            if key != self.RESERVED_KEY_FOR_NBITS and (not only_active or entry.value is True)
        ]
-        num_active_averagers = len([key for key, entry in result.value.items() if entry.value is True])
+        num_active_averagers = sum(1 for entry in result.value.values() if entry.value is True)

        suggested_nbits = self.get_suggested_nbits(result)
        if (
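Note on the to_base58()/from_base58() calls above: DHT subkeys and the protobuf endpoint field carry plain strings, so a PeerID crosses those boundaries as base58 text and is decoded back on read. A self-contained sketch of that round trip, assuming standard base58 (Bitcoin alphabet, leading zero bytes encoded as '1'); b58encode, b58decode, and ToyPeerID are illustrative stand-ins, not hivemind's implementation:

ALPHABET = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz"

def b58encode(raw: bytes) -> str:
    # Treat the bytes as a big integer; leading zero bytes become leading '1's.
    n = int.from_bytes(raw, "big")
    digits = ""
    while n:
        n, rem = divmod(n, 58)
        digits = ALPHABET[rem] + digits
    pad = len(raw) - len(raw.lstrip(b"\x00"))
    return "1" * pad + digits

def b58decode(text: str) -> bytes:
    n = 0
    for char in text:
        n = n * 58 + ALPHABET.index(char)
    body = n.to_bytes((n.bit_length() + 7) // 8, "big")
    pad = len(text) - len(text.lstrip("1"))
    return b"\x00" * pad + body

class ToyPeerID:
    """Hypothetical stand-in for hivemind.p2p.PeerID."""

    def __init__(self, raw: bytes):
        self._raw = raw

    def to_base58(self) -> str:
        return b58encode(self._raw)

    @classmethod
    def from_base58(cls, text: str) -> "ToyPeerID":
        return cls(b58decode(text))

    def __eq__(self, other: object) -> bool:
        return isinstance(other, ToyPeerID) and self._raw == other._raw

peer = ToyPeerID(b"\x00\x12 raw-peer-id-bytes")
encoded = peer.to_base58()  # a plain str: usable as a DHT subkey or proto field
assert ToyPeerID.from_base58(encoded) == peer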
(6 more changed files not shown)
