From 9d6b36d3b85136714a2b75bf88b32105e1819b36 Mon Sep 17 00:00:00 2001
From: Alexey Bukhtiyarov
Date: Sun, 21 Mar 2021 17:46:49 +0300
Subject: [PATCH 1/2] Fix scalar deserialization

---
 hivemind/utils/grpc.py     | 14 +++++++++++---
 tests/test_util_modules.py |  9 +++++++++
 2 files changed, 20 insertions(+), 3 deletions(-)

diff --git a/hivemind/utils/grpc.py b/hivemind/utils/grpc.py
index fd2bd9cff..ea4a0a1d8 100644
--- a/hivemind/utils/grpc.py
+++ b/hivemind/utils/grpc.py
@@ -215,9 +215,17 @@ def serialize_torch_tensor(tensor: torch.Tensor, compression_type=CompressionTyp
 
 def deserialize_torch_tensor(serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
     # TODO avoid copying the array (need to silence pytorch warning, because array is not writable)
+
+    def construct_torch_tensor(serialized_tensor, array, dtype=None):
+        """ Helper conversion function that handles some edge case """
+        if serialized_tensor.size:
+            return torch.as_tensor(array, dtype=dtype).view(*serialized_tensor.size)
+        else:
+            return torch.as_tensor(array, dtype=dtype)
+
     if serialized_tensor.compression == CompressionType.NONE:
         array = np.frombuffer(serialized_tensor.buffer, dtype=np.dtype(serialized_tensor.dtype)).copy()
-        tensor = torch.as_tensor(array).view(*serialized_tensor.size)
+        tensor = construct_torch_tensor(serialized_tensor, array)
     elif serialized_tensor.compression == CompressionType.MEANSTD_LAST_AXIS_FLOAT16:
         stats_size = list(serialized_tensor.size)
         stats_size[-1] = 1
@@ -227,10 +235,10 @@ def deserialize_torch_tensor(serialized_tensor: runtime_pb2.Ten
         means = torch.as_tensor(np.frombuffer(means, dtype=np.float32).copy()).view(*stats_size)
         stds = torch.as_tensor(np.frombuffer(stds, dtype=np.float32).copy()).view(*stats_size)
         array = np.frombuffer(serialized_tensor.buffer[:-8 * stats_count], dtype=np.float16).copy()
-        tensor = torch.as_tensor(array, dtype=torch.float32).view(*serialized_tensor.size).mul_(stds).add_(means)
+        tensor = construct_torch_tensor(serialized_tensor, array, torch.float32).mul_(stds).add_(means)
     elif serialized_tensor.compression == CompressionType.FLOAT16:
         array = np.frombuffer(serialized_tensor.buffer, dtype=np.float16).copy()
-        tensor = torch.as_tensor(array, dtype=torch.float32).view(*serialized_tensor.size)
+        tensor = construct_torch_tensor(serialized_tensor, array, torch.float32)
     else:
         raise ValueError(f"Unknown compression type: {serialized_tensor.compression}")
 
diff --git a/tests/test_util_modules.py b/tests/test_util_modules.py
index 7c0c17b56..5b8858586 100644
--- a/tests/test_util_modules.py
+++ b/tests/test_util_modules.py
@@ -121,6 +121,7 @@ async def wait_and_raise():
         await future
 
 
+
 def test_vector_compression(size=(128, 128, 64), alpha=5e-08):
     torch.manual_seed(0)
     from hivemind.proto.runtime_pb2 import CompressionType
@@ -194,6 +195,14 @@ def test_serialize_tensor():
     restored = hivemind.combine_from_streaming(chunks)
     assert torch.allclose(hivemind.deserialize_torch_tensor(restored), tensor)
 
+    scalar = torch.tensor(1.)
+    serialized_scalar = hivemind.serialize_torch_tensor(scalar, hivemind.CompressionType.NONE)
+    assert torch.allclose(hivemind.deserialize_torch_tensor(serialized_scalar), scalar)
+
+    serialized_scalar = hivemind.serialize_torch_tensor(scalar, hivemind.CompressionType.FLOAT16)
+    assert torch.allclose(hivemind.deserialize_torch_tensor(serialized_scalar), scalar)
+
+
 
 def test_serialize_tuple():
     test_pairs = (

From d69a9a8e87aae242d1d12c9dbd5f90fe99721adc Mon Sep 17 00:00:00 2001
From: Alexey Bukhtiyarov
Date: Sun, 21 Mar 2021 20:08:00 +0300
Subject: [PATCH 2/2] Fix scalar deserialization

---
 hivemind/utils/grpc.py | 24 ++++++++++++------------
 1 file changed, 12 insertions(+), 12 deletions(-)

diff --git a/hivemind/utils/grpc.py b/hivemind/utils/grpc.py
index ea4a0a1d8..e79614398 100644
--- a/hivemind/utils/grpc.py
+++ b/hivemind/utils/grpc.py
@@ -5,7 +5,7 @@
 import os
 import threading
 
-from typing import NamedTuple, Tuple, Optional, Union, Any, Dict, TypeVar, Type, Iterator, Iterable
+from typing import NamedTuple, Tuple, Optional, Union, Any, Dict, TypeVar, Type, Iterator, Iterable, Sequence
 
 import grpc
 import numpy as np
@@ -213,19 +213,19 @@ def serialize_torch_tensor(tensor: torch.Tensor, compression_type=CompressionTyp
     return proto
 
 
-def deserialize_torch_tensor(serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
-    # TODO avoid copying the array (need to silence pytorch warning, because array is not writable)
+def construct_torch_tensor(array: np.ndarray, size: Sequence, dtype: Optional[torch.dtype]=None):
+    """ Helper conversion function that handles edge case with scalar deserialization """
+    if size:
+        return torch.as_tensor(array, dtype=dtype).view(*size)
+    else:
+        return torch.as_tensor(array, dtype=dtype)
 
-    def construct_torch_tensor(serialized_tensor, array, dtype=None):
-        """ Helper conversion function that handles some edge case """
-        if serialized_tensor.size:
-            return torch.as_tensor(array, dtype=dtype).view(*serialized_tensor.size)
-        else:
-            return torch.as_tensor(array, dtype=dtype)
 
+def deserialize_torch_tensor(serialized_tensor: runtime_pb2.Tensor) -> torch.Tensor:
+    # TODO avoid copying the array (need to silence pytorch warning, because array is not writable)
     if serialized_tensor.compression == CompressionType.NONE:
         array = np.frombuffer(serialized_tensor.buffer, dtype=np.dtype(serialized_tensor.dtype)).copy()
-        tensor = construct_torch_tensor(serialized_tensor, array)
+        tensor = construct_torch_tensor(array, serialized_tensor.size)
     elif serialized_tensor.compression == CompressionType.MEANSTD_LAST_AXIS_FLOAT16:
         stats_size = list(serialized_tensor.size)
         stats_size[-1] = 1
@@ -235,10 +235,10 @@ def construct_torch_tensor(serialized_tensor, array, dtype=None):
         means = torch.as_tensor(np.frombuffer(means, dtype=np.float32).copy()).view(*stats_size)
         stds = torch.as_tensor(np.frombuffer(stds, dtype=np.float32).copy()).view(*stats_size)
         array = np.frombuffer(serialized_tensor.buffer[:-8 * stats_count], dtype=np.float16).copy()
-        tensor = construct_torch_tensor(serialized_tensor, array, torch.float32).mul_(stds).add_(means)
+        tensor = construct_torch_tensor(array, serialized_tensor.size, torch.float32).mul_(stds).add_(means)
     elif serialized_tensor.compression == CompressionType.FLOAT16:
         array = np.frombuffer(serialized_tensor.buffer, dtype=np.float16).copy()
-        tensor = construct_torch_tensor(serialized_tensor, array, torch.float32)
+        tensor = construct_torch_tensor(array, serialized_tensor.size, torch.float32)
     else:
         raise ValueError(f"Unknown compression type: {serialized_tensor.compression}")
 