Support deserializing empty tensors.

torch.frombuffer rejects a zero-length buffer, so both TensorSerializer and
NoHeaderTensorSerializer now build the empty tensor with torch.empty when the
payload carries no element bytes.

Review notes:
  * The payload offset is computed from shape_size rather than from the loop
    variable shape_idx, which is unbound when the tensor has zero dimensions.
  * Fixed the wording of the assertion message ("the ending index").

diff --git a/src/litdata/streaming/serializers.py b/src/litdata/streaming/serializers.py
index 11c6de03..cc4b107f 100644
--- a/src/litdata/streaming/serializers.py
+++ b/src/litdata/streaming/serializers.py
@@ -184,7 +184,13 @@ def deserialize(self, data: bytes) -> torch.Tensor:
         shape = []
         for shape_idx in range(shape_size):
             shape.append(np.frombuffer(data[8 + 4 * shape_idx : 8 + 4 * (shape_idx + 1)], np.uint32).item())
-        tensor = torch.frombuffer(data[8 + 4 * (shape_idx + 1) : len(data)], dtype=dtype)
+        idx_start = 8 + 4 * shape_size
+        idx_end = len(data)
+        if idx_end > idx_start:
+            tensor = torch.frombuffer(data[idx_start:idx_end], dtype=dtype)
+        else:
+            assert idx_start == idx_end, "The starting index should never be greater than the ending index."
+            tensor = torch.empty(shape, dtype=dtype)
         shape = torch.Size(shape)
         if tensor.shape == shape:
             return tensor
@@ -211,7 +217,11 @@ def serialize(self, item: torch.Tensor) -> Tuple[bytes, Optional[str]]:
 
     def deserialize(self, data: bytes) -> torch.Tensor:
         assert self._dtype
-        return torch.frombuffer(data, dtype=self._dtype)
+        if len(data) > 0:
+            tensor = torch.frombuffer(data, dtype=self._dtype)
+        else:
+            tensor = torch.empty((0,), dtype=self._dtype)
+        return tensor
 
     def can_serialize(self, item: torch.Tensor) -> bool:
         return isinstance(item, torch.Tensor) and type(item) == torch.Tensor and len(item.shape) == 1
diff --git a/tests/streaming/test_serializer.py b/tests/streaming/test_serializer.py
index d7815ff0..ec55fa32 100644
--- a/tests/streaming/test_serializer.py
+++ b/tests/streaming/test_serializer.py
@@ -244,3 +244,31 @@ class CustomSerializer(NoHeaderTensorSerializer):
 
     assert isinstance(serializers["no_header_tensor"], CustomSerializer)
     assert isinstance(serializers["custom"], CustomSerializer)
+
+
+def test_deserialize_empty_tensor():
+    serializer = TensorSerializer()
+    t = torch.ones((0, 3)).int()
+    data, name = serializer.serialize(t)
+    new_t = serializer.deserialize(data)
+    assert torch.equal(t, new_t)
+
+    t = torch.ones((0, 3)).float()
+    data, name = serializer.serialize(t)
+    new_t = serializer.deserialize(data)
+    assert torch.equal(t, new_t)
+
+
+def test_deserialize_empty_no_header_tensor():
+    serializer = NoHeaderTensorSerializer()
+    t = torch.ones((0,)).int()
+    data, name = serializer.serialize(t)
+    serializer.setup(name)
+    new_t = serializer.deserialize(data)
+    assert torch.equal(t, new_t)
+
+    t = torch.ones((0,)).float()
+    data, name = serializer.serialize(t)
+    serializer.setup(name)
+    new_t = serializer.deserialize(data)
+    assert torch.equal(t, new_t)