From 2262925fa663c4b0e18011909cd2d73f9cf56967 Mon Sep 17 00:00:00 2001
From: Aleksandr Borzunov
Date: Thu, 20 Jul 2023 20:07:23 +0000
Subject: [PATCH] Fix .lstrip()

---
 hivemind/compression/base.py         | 2 +-
 hivemind/compression/quantization.py | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/hivemind/compression/base.py b/hivemind/compression/base.py
index 630603546..36e3fc3f8 100644
--- a/hivemind/compression/base.py
+++ b/hivemind/compression/base.py
@@ -84,7 +84,7 @@ class NoCompression(CompressionBase):
     def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
         tensor = tensor.detach()
         shape = tensor.shape
-        dtype_name = str(tensor.dtype).lstrip("torch.")
+        dtype_name = str(tensor.dtype).replace("torch.", "")
         raw_data = tensor
         if tensor.dtype == torch.bfloat16:
             if USE_LEGACY_BFLOAT16:  # legacy mode: convert to fp32
diff --git a/hivemind/compression/quantization.py b/hivemind/compression/quantization.py
index 0f2e4e098..a1fdc11ea 100644
--- a/hivemind/compression/quantization.py
+++ b/hivemind/compression/quantization.py
@@ -140,7 +140,7 @@ def quantize(

     def compress(self, tensor: torch.Tensor, info: CompressionInfo, allow_inplace: bool = False) -> runtime_pb2.Tensor:
         tensor = tensor.detach()
-        dtype_name = str(tensor.dtype).lstrip("torch.")
+        dtype_name = str(tensor.dtype).replace("torch.", "")
         if tensor.dtype == torch.bfloat16:
             tensor = tensor.to(torch.float32)
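
Why the change matters (an illustrative note, not part of the patch itself): str.lstrip treats its argument as a set of characters to strip from the left, not as a literal prefix. "torch.float32" happens to survive, but any dtype name whose first characters fall in the set {'t', 'o', 'r', 'c', 'h', '.'} gets mangled. The sketch below assumes only that torch is importable and that str(tensor.dtype) yields strings like "torch.complex64":

    import torch

    # .lstrip("torch.") strips any run of leading characters drawn from
    # {'t', 'o', 'r', 'c', 'h', '.'}; it does not remove the literal prefix.
    print(str(torch.float32).lstrip("torch."))    # 'float32'  (works by coincidence)
    print(str(torch.complex64).lstrip("torch."))  # 'mplex64'  ('c' and 'o' are stripped too)

    # Replacing the prefix text keeps the dtype name intact, matching the patch.
    print(str(torch.complex64).replace("torch.", ""))  # 'complex64'

On Python 3.9+, str.removeprefix("torch.") would express the same intent even more directly; the patch uses .replace(), which also works on older interpreters.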