From 19a46af9da0a386e846202d6eaab249dfbc150ba Mon Sep 17 00:00:00 2001 From: Stefano Fiorucci Date: Tue, 23 Apr 2024 19:03:41 +0200 Subject: [PATCH] add `__eq__` method to `SparseEmbedding` (#7574) * add __eq__ method to SparseEmbedding * reno * improve reno --- haystack/dataclasses/sparse_embedding.py | 3 +++ releasenotes/notes/sparse-emb-eq-773ef04ae3ed83ea.yaml | 4 ++++ test/dataclasses/test_sparse_embedding.py | 8 ++++++++ 3 files changed, 15 insertions(+) create mode 100644 releasenotes/notes/sparse-emb-eq-773ef04ae3ed83ea.yaml diff --git a/haystack/dataclasses/sparse_embedding.py b/haystack/dataclasses/sparse_embedding.py index 5fbfc2bbd9..39355af598 100644 --- a/haystack/dataclasses/sparse_embedding.py +++ b/haystack/dataclasses/sparse_embedding.py @@ -20,6 +20,9 @@ def __init__(self, indices: List[int], values: List[float]): self.indices = indices self.values = values + def __eq__(self, other): + return self.indices == other.indices and self.values == other.values + def to_dict(self): """ Convert the SparseEmbedding object to a dictionary. diff --git a/releasenotes/notes/sparse-emb-eq-773ef04ae3ed83ea.yaml b/releasenotes/notes/sparse-emb-eq-773ef04ae3ed83ea.yaml new file mode 100644 index 0000000000..4074680bcf --- /dev/null +++ b/releasenotes/notes/sparse-emb-eq-773ef04ae3ed83ea.yaml @@ -0,0 +1,4 @@ +--- +enhancements: + - | + Add an `__eq__` method to `SparseEmbedding` class to compare two `SparseEmbedding` objects. diff --git a/test/dataclasses/test_sparse_embedding.py b/test/dataclasses/test_sparse_embedding.py index f3fc889aa5..0617610189 100644 --- a/test/dataclasses/test_sparse_embedding.py +++ b/test/dataclasses/test_sparse_embedding.py @@ -21,3 +21,11 @@ def test_from_dict(self): se = SparseEmbedding.from_dict({"indices": [0, 2, 4], "values": [0.1, 0.2, 0.3]}) assert se.indices == [0, 2, 4] assert se.values == [0.1, 0.2, 0.3] + + def test_eq(self): + se1 = SparseEmbedding(indices=[0, 2, 4], values=[0.1, 0.2, 0.3]) + se2 = SparseEmbedding(indices=[0, 2, 4], values=[0.1, 0.2, 0.3]) + assert se1 == se2 + + se3 = SparseEmbedding(indices=[0, 2, 4], values=[0.1, 0.2, 0.4]) + assert se1 != se3