diff --git a/hivemind/client/averaging/__init__.py b/hivemind/client/averaging/__init__.py
index 2d8767a77..fee6fb95a 100644
--- a/hivemind/client/averaging/__init__.py
+++ b/hivemind/client/averaging/__init__.py
@@ -12,6 +12,7 @@ from typing import Sequence, Optional, Tuple, Any, Union, Dict, AsyncIterator
 
 import grpc
+from grpc._cython.cygrpc import InternalError
 import torch
 import numpy as np
 
@@ -239,10 +240,13 @@ async def _step(self, *, future: MPFuture, gather_binary: bytes, allow_retries:
                gathered_data_by_peer = dict(zip(allreduce_group.ordered_group_endpoints, gathered_items))
                future.set_result(gathered_data_by_peer)
 
-            except (AllreduceException, MatchmakingException):
+            except (AllreduceException, MatchmakingException, asyncio.exceptions.InvalidStateError,
+                    grpc.RpcError, grpc.aio.AioRpcError, InternalError) as e:
                time_elapsed = get_dht_time() - start_time
                if not allow_retries or (timeout is not None and timeout < time_elapsed):
                    future.set_result(None)
+                else:
+                    logger.debug(f"caught {e}, retrying")
 
            except Exception as e:
                future.set_exception(e)
@@ -311,9 +315,9 @@ async def rpc_download_state(self, request: averaging_pb2.DownloadRequest, conte
                                  ) -> AsyncIterator[averaging_pb2.DownloadData]:
         """
         Get the up-to-date trainer state from a peer.
-        The state consists of two parts: (metadata, tensors)
+        The state consists of two parts: (serialized_metadata, tensors)
 
-         - metadata is a small serialized bytestring meant to store scalars and hyperparameters
+         - serialized_metadata is a small serialized bytestring meant to store scalars and hyperparameters
          - tensors is a sequence of pytorch tensors that represent model parameters or optimizer statistics
         """
         chunk_size_bytes = self.matchmaking_kwargs.get('chunk_size_bytes', DEFAULT_CHUNK_SIZE_BYTES)
@@ -342,15 +346,15 @@ async def _get_current_state_from_host_process(self):
         self._pipe.send(('_TRIGGER_GET_CURRENT_STATE', _future))
         return await future
 
-    def load_state_from_peers(self, wait=True) -> Optional[Tuple[bytes, Sequence[torch.Tensor]]]:
+    def load_state_from_peers(self, wait=True) -> Optional[Tuple[Any, Sequence[torch.Tensor]]]:
         """
         Try to download the latest optimizer state one of the existing peer.
-        :returns: on success, return a 2-tuple with (serialized_metadata, tensors), where
+        :returns: on success, return a 2-tuple with (metadata, tensors), where
 
-         - serialized_metadata is a small bytestring containing **serialized** metadata (e.g. hyperparameters)
+         - metadata is a small object containing metadata (e.g. hyperparameters, scalars, etc)
          - tensors is a sequence of pytorch tensors meant to contain peer's model weights and optimizer statistics
 
-        The exact contents of both serialized_metadata and tensors are determined by get_current_state method
+        The exact contents of both metadata and tensors are determined by get_current_state method
         """
         future, _future = MPFuture.make_pair()
         self.pipe.send(('_load_state_from_peers', [], dict(future=_future)))
@@ -441,7 +445,7 @@ def _background_thread_fetch_current_state(serializer: SerializerBase, pipe: mp.
     Executed in the host process as a background thread. Fetches the averager state when asked by peers.
     :param serializer: a serializer with which to convert metadata into bytes
     :param pipe: DecentralizedAverager's control pipe (from host process side)
-    :param get_current_state_ref: a WeakMethod wrapped around DecentraliedAverager.get_current_state (instance-bound)
+    :param get_current_state_ref: a WeakMethod wrapped around DecentralizedAverager.get_current_state (instance-bound)
     """
     while True:
         trigger, future = pipe.recv()
diff --git a/hivemind/client/averaging/allreduce.py b/hivemind/client/averaging/allreduce.py
index eda571195..53b79f51c 100644
--- a/hivemind/client/averaging/allreduce.py
+++ b/hivemind/client/averaging/allreduce.py
@@ -153,7 +153,12 @@ async def _communicate_with_peer(self, peer_endpoint: Endpoint, local_part: torc
                                      f" instead of {averaging_pb2.MessageCode.Name(averaging_pb2.AVERAGED_PART)},"
                                      f" allreduce failed")
 
-        averaged_part = deserialize_torch_tensor(combine_from_streaming([message.tensor_part for message in outputs]))
+        try:
+            averaged_part = deserialize_torch_tensor(combine_from_streaming(
+                [message.tensor_part for message in outputs]))
+        except RuntimeError as e:
+            raise AllreduceException(f"Could not deserialize averaged part from {peer_endpoint}: {e}")
+
         self.register_averaged_part(peer_endpoint, averaged_part)
         return averaged_part
 
@@ -182,7 +187,11 @@ async def run(self) -> Sequence[torch.Tensor]:
     async def accumulate_part_streaming(self, source: Endpoint, stream_messages: Iterable[runtime_pb2.Tensor]
                                         ) -> Iterable[runtime_pb2.Tensor]:
         """ accumulate_part using streams of serialized tensors. Used to prevent duplicate work in serialization """
-        tensor_part: torch.Tensor = deserialize_torch_tensor(combine_from_streaming(stream_messages))
+        try:
+            tensor_part = deserialize_torch_tensor(combine_from_streaming(stream_messages))
+        except RuntimeError as e:
+            raise AllreduceException(f"Could not deserialize tensor part from {source} for streaming: {e}")
+
         averaged_part = await self.accumulate_part(source, tensor_part)
         if not self.averaged_part_stream.done():
             serialized_tensor = serialize_torch_tensor(averaged_part, self.compression_type, allow_inplace=False)
diff --git a/tests/test_util_modules.py b/tests/test_util_modules.py
index e730070e1..1981530da 100644
--- a/tests/test_util_modules.py
+++ b/tests/test_util_modules.py
@@ -1,5 +1,6 @@
 import asyncio
 import torch
+import numpy as np
 import pytest
 import hivemind
 
@@ -57,7 +58,7 @@ def test_mpfuture_cancel():
     with pytest.raises(RuntimeError):
         future.set_result(123)
     with pytest.raises(RuntimeError):
-        future.set_exception(NotImplementedError)
+        future.set_exception(NotImplementedError())
     assert future.cancelled() and future.done() and not future.running()
 
 
@@ -192,6 +193,42 @@ def test_serialize_tensor():
     restored = hivemind.combine_from_streaming(chunks)
     assert torch.allclose(hivemind.deserialize_torch_tensor(restored), tensor)
 
+
+def test_split_parts():
+    tensor = torch.randn(910, 512)
+    serialized_tensor_part = hivemind.utils.serialize_torch_tensor(tensor, allow_inplace=False)
+    chunks1 = list(hivemind.utils.split_for_streaming(serialized_tensor_part, 16384))
+    assert len(chunks1) == int(np.ceil(tensor.numel() * tensor.element_size() / 16384))
+
+    chunks2 = list(hivemind.utils.split_for_streaming(serialized_tensor_part, 10_000))
+    assert len(chunks2) == int(np.ceil(tensor.numel() * tensor.element_size() / 10_000))
+
+    chunks3 = list(hivemind.utils.split_for_streaming(serialized_tensor_part, 10 ** 9))
+    assert len(chunks3) == 1
+
+    compressed_tensor_part = hivemind.utils.serialize_torch_tensor(tensor, hivemind.CompressionType.FLOAT16,
+                                                                    allow_inplace=False)
+    chunks4 = list(hivemind.utils.split_for_streaming(compressed_tensor_part, 16384))
+    assert len(chunks4) == int(np.ceil(tensor.numel() * 2 / 16384))
+
+    combined1 = hivemind.utils.combine_from_streaming(chunks1)
+    combined2 = hivemind.utils.combine_from_streaming(iter(chunks2))
+    combined3 = hivemind.utils.combine_from_streaming(chunks3)
+    combined4 = hivemind.utils.combine_from_streaming(chunks4)
+    for combined in combined1, combined2, combined3:
+        assert torch.allclose(tensor, hivemind.deserialize_torch_tensor(combined), rtol=1e-5, atol=1e-8)
+
+    assert torch.allclose(tensor, hivemind.deserialize_torch_tensor(combined4), rtol=1e-3, atol=1e-3)
+
+    combined_incomplete = hivemind.utils.combine_from_streaming(chunks4[:5])
+    combined_incomplete2 = hivemind.utils.combine_from_streaming(chunks4[:1])
+    combined_incomplete3 = hivemind.utils.combine_from_streaming(chunks4[:-1])
+    for combined in combined_incomplete, combined_incomplete2, combined_incomplete3:
+        with pytest.raises(RuntimeError):
+            hivemind.deserialize_torch_tensor(combined)
+    # note: we rely on this being RuntimeError in hivemind.client.averager.allreduce.AllreduceProtocol
+
+
 
 def test_generic_data_classes():
     from hivemind.utils import ValueWithExpiration, HeapEntry, DHTExpiration
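
Not part of the patch: a minimal sketch of how these changes surface to code that drives a DecentralizedAverager. Only load_state_from_peers(wait=True) -> Optional[Tuple[Any, Sequence[torch.Tensor]]] is taken directly from the diff; the step(...) keywords and the setup arguments below are assumptions inferred from the _step signature and may differ from the real API.

import torch
import hivemind

# Hypothetical setup: names and constructor arguments are illustrative only.
dht = hivemind.DHT(start=True)
averager = hivemind.DecentralizedAverager(
    averaged_tensors=[torch.zeros(8)], dht=dht, start=True,
    prefix='demo_run', target_group_size=2)

# With this patch, transient transport failures (grpc.RpcError, AioRpcError, cygrpc InternalError)
# are retried inside _step until `timeout` elapses; a failed round then returns None instead of raising.
result = averager.step(allow_retries=True, timeout=60.0)
if result is None:
    print("averaging round did not succeed within 60s")

# load_state_from_peers() now returns (metadata, tensors) with metadata already deserialized
# (see the docstring change above), rather than a serialized bytestring.
loaded = averager.load_state_from_peers(wait=True)
if loaded is not None:
    metadata, tensors = loaded  # metadata is a plain Python object, e.g. a dict of hyperparameters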