From c2ec3f5fde2138466182ea7f5cfa7612e3fb2d6e Mon Sep 17 00:00:00 2001
From: Massimiliano Pippi
Date: Wed, 4 Oct 2023 17:23:12 +0200
Subject: [PATCH] feat: add File type to preview package (#5873)

* add Blob type

* review feedback

* fix tests and naming

* Update add-blob-type-2a9476a39841f54d.yaml

* removed unused import

---------

Co-authored-by: Stefano Fiorucci <44616784+anakin87@users.noreply.github.com>
---
 haystack/preview/dataclasses/__init__.py      |  3 +-
 haystack/preview/dataclasses/byte_stream.py   | 37 +++++++++++++++++++
 .../notes/add-blob-type-2a9476a39841f54d.yaml |  5 +++
 test/preview/dataclasses/test_byte_stream.py  | 33 +++++++++++++++++
 4 files changed, 77 insertions(+), 1 deletion(-)
 create mode 100644 haystack/preview/dataclasses/byte_stream.py
 create mode 100644 releasenotes/notes/add-blob-type-2a9476a39841f54d.yaml
 create mode 100644 test/preview/dataclasses/test_byte_stream.py

diff --git a/haystack/preview/dataclasses/__init__.py b/haystack/preview/dataclasses/__init__.py
index 5a0d8489f8..6873ac0ccb 100644
--- a/haystack/preview/dataclasses/__init__.py
+++ b/haystack/preview/dataclasses/__init__.py
@@ -1,4 +1,5 @@
 from haystack.preview.dataclasses.document import Document
 from haystack.preview.dataclasses.answer import ExtractedAnswer, GeneratedAnswer, Answer
+from haystack.preview.dataclasses.byte_stream import ByteStream
 
-__all__ = ["Document", "ExtractedAnswer", "GeneratedAnswer", "Answer"]
+__all__ = ["Document", "ExtractedAnswer", "GeneratedAnswer", "Answer", "ByteStream"]
diff --git a/haystack/preview/dataclasses/byte_stream.py b/haystack/preview/dataclasses/byte_stream.py
new file mode 100644
index 0000000000..fe006fcf85
--- /dev/null
+++ b/haystack/preview/dataclasses/byte_stream.py
@@ -0,0 +1,37 @@
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Dict, Any
+
+
+@dataclass(frozen=True)
+class ByteStream:
+    """
+    Base data class representing a binary object in the Haystack API.
+    """
+
+    data: bytes
+    metadata: Dict[str, Any] = field(default_factory=dict, hash=False)
+
+    def to_file(self, destination_path: Path):
+        with open(destination_path, "wb") as fd:
+            fd.write(self.data)
+
+    @classmethod
+    def from_file_path(cls, filepath: Path) -> "ByteStream":
+        """
+        Create a ByteStream from the contents read from a file.
+
+        :param filepath: A valid path to a file.
+        """
+        with open(filepath, "rb") as fd:
+            return cls(data=fd.read())
+
+    @classmethod
+    def from_string(cls, text: str, encoding: str = "utf-8") -> "ByteStream":
+        """
+        Create a ByteStream encoding a string.
+
+        :param text: The string to encode
+        :param encoding: The encoding used to convert the string into bytes
+        """
+        return cls(data=text.encode(encoding))
diff --git a/releasenotes/notes/add-blob-type-2a9476a39841f54d.yaml b/releasenotes/notes/add-blob-type-2a9476a39841f54d.yaml
new file mode 100644
index 0000000000..163d9631ef
--- /dev/null
+++ b/releasenotes/notes/add-blob-type-2a9476a39841f54d.yaml
@@ -0,0 +1,5 @@
+---
+preview:
+  - |
+    Add ByteStream type to send binary raw data across components
+    in a pipeline.
diff --git a/test/preview/dataclasses/test_byte_stream.py b/test/preview/dataclasses/test_byte_stream.py
new file mode 100644
index 0000000000..05d40eb79e
--- /dev/null
+++ b/test/preview/dataclasses/test_byte_stream.py
@@ -0,0 +1,33 @@
+import io
+
+from haystack.preview.dataclasses import ByteStream
+
+import pytest
+
+
+@pytest.mark.unit
+def test_from_file_path(tmp_path, request):
+    test_bytes = "Hello, world!\n".encode()
+    test_path = tmp_path / request.node.name
+    with open(test_path, "wb") as fd:
+        assert fd.write(test_bytes)
+
+    b = ByteStream.from_file_path(test_path)
+    assert b.data == test_bytes
+
+
+@pytest.mark.unit
+def test_from_string():
+    test_string = "Hello, world!"
+    b = ByteStream.from_string(test_string)
+    assert b.data.decode() == test_string
+
+
+@pytest.mark.unit
+def test_to_file(tmp_path, request):
+    test_str = "Hello, world!\n"
+    test_path = tmp_path / request.node.name
+
+    ByteStream(test_str.encode()).to_file(test_path)
+    with open(test_path, "rb") as fd:
+        assert fd.read().decode() == test_str