Handle edge cases in DecentralizedAverager #171

Merged: 16 commits, Mar 4, 2021
Changes from all commits
20 changes: 12 additions & 8 deletions hivemind/client/averaging/__init__.py
@@ -12,6 +12,7 @@
 from typing import Sequence, Optional, Tuple, Any, Union, Dict, AsyncIterator

 import grpc
+from grpc._cython.cygrpc import InternalError
 import torch
 import numpy as np

@@ -239,10 +240,13 @@ async def _step(self, *, future: MPFuture, gather_binary: bytes, allow_retries:
                     gathered_data_by_peer = dict(zip(allreduce_group.ordered_group_endpoints, gathered_items))
                     future.set_result(gathered_data_by_peer)

-                except (AllreduceException, MatchmakingException):
+                except (AllreduceException, MatchmakingException, asyncio.exceptions.InvalidStateError,
+                        grpc.RpcError, grpc.aio.AioRpcError, InternalError) as e:
                     time_elapsed = get_dht_time() - start_time
                     if not allow_retries or (timeout is not None and timeout < time_elapsed):
                         future.set_result(None)
+                    else:
+                        logger.debug(f"caught {e}, retrying")

         except Exception as e:
             future.set_exception(e)
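The new except clause converts transient failures (matchmaking races, dropped gRPC channels, stray InternalError from cygrpc) into retries, bounded by the caller's timeout. A minimal standalone sketch of that retry policy, where attempt_once and TRANSIENT_ERRORS are hypothetical stand-ins rather than hivemind internals:

```python
import asyncio
import logging
import time

logger = logging.getLogger(__name__)

# Hypothetical stand-ins for the transient errors named in the except clause above
TRANSIENT_ERRORS = (ConnectionError, asyncio.InvalidStateError)

async def run_with_retries(attempt_once, allow_retries: bool = True, timeout: float = None):
    """ Keep calling attempt_once() until it succeeds or the deadline expires; None signals failure """
    start_time = time.monotonic()
    while True:
        try:
            return await attempt_once()
        except TRANSIENT_ERRORS as e:
            time_elapsed = time.monotonic() - start_time
            if not allow_retries or (timeout is not None and timeout < time_elapsed):
                return None  # mirrors future.set_result(None) in the diff above
            logger.debug(f"caught {e}, retrying")
```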
@@ -311,9 +315,9 @@ async def rpc_download_state(self, request: averaging_pb2.DownloadRequest, conte
                              ) -> AsyncIterator[averaging_pb2.DownloadData]:
         """
         Get the up-to-date trainer state from a peer.
-        The state consists of two parts: (metadata, tensors)
+        The state consists of two parts: (serialized_metadata, tensors)

-        - metadata is a small serialized bytestring meant to store scalars and hyperparameters
+        - serialized_metadata is a small serialized bytestring meant to store scalars and hyperparameters
         - tensors is a sequence of pytorch tensors that represent model parameters or optimizer statistics
         """
         chunk_size_bytes = self.matchmaking_kwargs.get('chunk_size_bytes', DEFAULT_CHUNK_SIZE_BYTES)
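For intuition about what rpc_download_state streams: each tensor is serialized once, split into chunks of at most chunk_size_bytes, and sent as a stream of messages; the receiver recombines and deserializes them. A hedged sketch using the same split/combine helpers exercised in tests/test_util_modules.py below (the chunk size here is an arbitrary example, not DEFAULT_CHUNK_SIZE_BYTES):

```python
import torch
import hivemind

chunk_size_bytes = 2 ** 16  # example value; the server reads it from matchmaking_kwargs
tensor = torch.randn(910, 512)

# serialize once, split into bounded chunks, recombine and deserialize on the other side
serialized = hivemind.utils.serialize_torch_tensor(tensor, allow_inplace=False)
chunks = list(hivemind.utils.split_for_streaming(serialized, chunk_size_bytes))
restored = hivemind.deserialize_torch_tensor(hivemind.combine_from_streaming(chunks))
assert torch.allclose(tensor, restored)
```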
@@ -342,15 +346,15 @@ async def _get_current_state_from_host_process(self):
         self._pipe.send(('_TRIGGER_GET_CURRENT_STATE', _future))
         return await future

-    def load_state_from_peers(self, wait=True) -> Optional[Tuple[bytes, Sequence[torch.Tensor]]]:
+    def load_state_from_peers(self, wait=True) -> Optional[Tuple[Any, Sequence[torch.Tensor]]]:
         """
         Try to download the latest optimizer state from one of the existing peers.
-        :returns: on success, return a 2-tuple with (serialized_metadata, tensors), where
+        :returns: on success, return a 2-tuple with (metadata, tensors), where

-          - serialized_metadata is a small bytestring containing **serialized** metadata (e.g. hyperparameters)
+          - metadata is a small object containing metadata (e.g. hyperparameters, scalars, etc)
           - tensors is a sequence of pytorch tensors meant to contain peer's model weights and optimizer statistics

-        The exact contents of both serialized_metadata and tensors are determined by get_current_state method
+        The exact contents of both metadata and tensors are determined by get_current_state method
         """
         future, _future = MPFuture.make_pair()
         self.pipe.send(('_load_state_from_peers', [], dict(future=_future)))
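A hedged usage sketch of the new signature (averager and model are assumed to already exist; the point is that metadata now arrives as a deserialized object rather than a raw bytestring):

```python
import torch

# Hypothetical usage: `averager` is a running DecentralizedAverager,
# `model` is the local torch.nn.Module we want to bring up to date.
loaded_state = averager.load_state_from_peers(wait=True)
if loaded_state is not None:
    metadata, tensors = loaded_state  # metadata is already deserialized, e.g. a dict of hyperparameters
    with torch.no_grad():
        for local_tensor, peer_tensor in zip(model.parameters(), tensors):
            local_tensor.copy_(peer_tensor)
```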
@@ -441,7 +445,7 @@ def _background_thread_fetch_current_state(serializer: SerializerBase, pipe: mp.
     Executed in the host process as a background thread. Fetches the averager state when asked by peers.
     :param serializer: a serializer with which to convert metadata into bytes
     :param pipe: DecentralizedAverager's control pipe (from host process side)
-    :param get_current_state_ref: a WeakMethod wrapped around DecentraliedAverager.get_current_state (instance-bound)
+    :param get_current_state_ref: a WeakMethod wrapped around DecentralizedAverager.get_current_state (instance-bound)
     """
     while True:
         trigger, future = pipe.recv()
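Only the first two lines of the loop are visible in this hunk. A sketch of what one iteration plausibly does with them (everything past pipe.recv() is an assumption, not the merged code): resolve the WeakMethod, call get_current_state on the host-side averager, serialize the metadata, and fulfill the future.

```python
from weakref import WeakMethod

def handle_fetch_request(serializer, future, get_current_state_ref: WeakMethod):
    # Assumed handler body; the WeakMethod returns None once the averager is garbage-collected
    get_current_state = get_current_state_ref()
    if get_current_state is None:
        future.set_exception(RuntimeError("averager no longer exists"))
        return
    metadata, tensors = get_current_state()
    future.set_result((serializer.dumps(metadata), tensors))
```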
13 changes: 11 additions & 2 deletions hivemind/client/averaging/allreduce.py
@@ -153,7 +153,12 @@ async def _communicate_with_peer(self, peer_endpoint: Endpoint, local_part: torc
                                          f" instead of {averaging_pb2.MessageCode.Name(averaging_pb2.AVERAGED_PART)},"
                                          f" allreduce failed")

-        averaged_part = deserialize_torch_tensor(combine_from_streaming([message.tensor_part for message in outputs]))
+        try:
+            averaged_part = deserialize_torch_tensor(combine_from_streaming(
+                [message.tensor_part for message in outputs]))
+        except RuntimeError as e:
+            raise AllreduceException(f"Could not deserialize averaged part from {peer_endpoint}: {e}")
+
         self.register_averaged_part(peer_endpoint, averaged_part)
         return averaged_part

@@ -182,7 +187,11 @@ async def run(self) -> Sequence[torch.Tensor]:
     async def accumulate_part_streaming(self, source: Endpoint, stream_messages: Iterable[runtime_pb2.Tensor]
                                         ) -> Iterable[runtime_pb2.Tensor]:
         """ accumulate_part using streams of serialized tensors. Used to prevent duplicate work in serialization """
-        tensor_part: torch.Tensor = deserialize_torch_tensor(combine_from_streaming(stream_messages))
+        try:
+            tensor_part = deserialize_torch_tensor(combine_from_streaming(stream_messages))
+        except RuntimeError as e:
+            raise AllreduceException(f"Could not deserialize tensor part from {source} for streaming {e}")
Comment on lines +192 to +193

Member:
Ideally, we should narrow down the scope of this exception:

  1. Do we want to catch errors that happen only because of streaming? If so, enclose combine_from_streaming inside the try-except block
  2. Do we want to catch any kind of specifically deserialization errors? If so, move this check inside deserialize_torch_tensor. If you still want an AllreduceException here, you can then catch a specific SerializerException

Member Author:
The problem is that errors in combine_from_streaming only manifest in deserialize_torch_tensor.
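To make the reviewer's second option concrete, a sketch of a dedicated deserialization error (SerializerException is hypothetical here, not an existing hivemind class):

```python
import hivemind

class SerializerException(Exception):
    """ Hypothetical: raised when a serialized tensor cannot be decoded """

def deserialize_checked(serialized_tensor):
    # wrap the deserializer so call sites can catch a narrow, named error
    try:
        return hivemind.deserialize_torch_tensor(serialized_tensor)
    except RuntimeError as e:
        raise SerializerException(f"malformed serialized tensor: {e}") from e
```

As the author notes, option 1 alone would not catch these failures: combine_from_streaming accepts a truncated stream without complaint, and the damage only surfaces when deserialize_torch_tensor parses the combined message; the incomplete-chunk cases in test_split_parts below exercise exactly this.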


         averaged_part = await self.accumulate_part(source, tensor_part)
         if not self.averaged_part_stream.done():
             serialized_tensor = serialize_torch_tensor(averaged_part, self.compression_type, allow_inplace=False)
39 changes: 38 additions & 1 deletion tests/test_util_modules.py
@@ -1,5 +1,6 @@
 import asyncio
 import torch
+import numpy as np

 import pytest
 import hivemind
@@ -57,7 +58,7 @@ def test_mpfuture_cancel():
     with pytest.raises(RuntimeError):
         future.set_result(123)
     with pytest.raises(RuntimeError):
-        future.set_exception(NotImplementedError)
+        future.set_exception(NotImplementedError())
     assert future.cancelled() and future.done() and not future.running()


@@ -192,6 +193,42 @@ def test_serialize_tensor():
     restored = hivemind.combine_from_streaming(chunks)
     assert torch.allclose(hivemind.deserialize_torch_tensor(restored), tensor)

+
+def test_split_parts():
+    tensor = torch.randn(910, 512)
+    serialized_tensor_part = hivemind.utils.serialize_torch_tensor(tensor, allow_inplace=False)
+    chunks1 = list(hivemind.utils.split_for_streaming(serialized_tensor_part, 16384))
+    assert len(chunks1) == int(np.ceil(tensor.numel() * tensor.element_size() / 16384))
+
+    chunks2 = list(hivemind.utils.split_for_streaming(serialized_tensor_part, 10_000))
+    assert len(chunks2) == int(np.ceil(tensor.numel() * tensor.element_size() / 10_000))
+
+    chunks3 = list(hivemind.utils.split_for_streaming(serialized_tensor_part, 10 ** 9))
+    assert len(chunks3) == 1
+
+    compressed_tensor_part = hivemind.utils.serialize_torch_tensor(tensor, hivemind.CompressionType.FLOAT16,
+                                                                   allow_inplace=False)
+    chunks4 = list(hivemind.utils.split_for_streaming(compressed_tensor_part, 16384))
+    assert len(chunks4) == int(np.ceil(tensor.numel() * 2 / 16384))
+
+    combined1 = hivemind.utils.combine_from_streaming(chunks1)
+    combined2 = hivemind.utils.combine_from_streaming(iter(chunks2))
+    combined3 = hivemind.utils.combine_from_streaming(chunks3)
+    combined4 = hivemind.utils.combine_from_streaming(chunks4)
+    for combined in combined1, combined2, combined3:
+        assert torch.allclose(tensor, hivemind.deserialize_torch_tensor(combined), rtol=1e-5, atol=1e-8)
+
+    assert torch.allclose(tensor, hivemind.deserialize_torch_tensor(combined4), rtol=1e-3, atol=1e-3)
+
+    combined_incomplete = hivemind.utils.combine_from_streaming(chunks4[:5])
+    combined_incomplete2 = hivemind.utils.combine_from_streaming(chunks4[:1])
+    combined_incomplete3 = hivemind.utils.combine_from_streaming(chunks4[:-1])
+    for combined in combined_incomplete, combined_incomplete2, combined_incomplete3:
+        with pytest.raises(RuntimeError):
+            hivemind.deserialize_torch_tensor(combined)
+    # note: we rely on this being RuntimeError in hivemind.client.averager.allreduce.AllreduceProtocol
+
+
 def test_generic_data_classes():
     from hivemind.utils import ValueWithExpiration, HeapEntry, DHTExpiration