Skip to content

Commit

Permalink
add __eq__ method to SparseEmbedding (#7574)
Browse files Browse the repository at this point in the history
* add __eq__ method to SparseEmbedding

* reno

* improve reno
  • Loading branch information
anakin87 authored Apr 23, 2024
1 parent 958f1eb commit 19a46af
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 0 deletions.
3 changes: 3 additions & 0 deletions haystack/dataclasses/sparse_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,9 @@ def __init__(self, indices: List[int], values: List[float]):
self.indices = indices
self.values = values

def __eq__(self, other):
return self.indices == other.indices and self.values == other.values

def to_dict(self):
"""
Convert the SparseEmbedding object to a dictionary.
Expand Down
4 changes: 4 additions & 0 deletions releasenotes/notes/sparse-emb-eq-773ef04ae3ed83ea.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
---
enhancements:
- |
Add an `__eq__` method to `SparseEmbedding` class to compare two `SparseEmbedding` objects.
8 changes: 8 additions & 0 deletions test/dataclasses/test_sparse_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,3 +21,11 @@ def test_from_dict(self):
se = SparseEmbedding.from_dict({"indices": [0, 2, 4], "values": [0.1, 0.2, 0.3]})
assert se.indices == [0, 2, 4]
assert se.values == [0.1, 0.2, 0.3]

def test_eq(self):
se1 = SparseEmbedding(indices=[0, 2, 4], values=[0.1, 0.2, 0.3])
se2 = SparseEmbedding(indices=[0, 2, 4], values=[0.1, 0.2, 0.3])
assert se1 == se2

se3 = SparseEmbedding(indices=[0, 2, 4], values=[0.1, 0.2, 0.4])
assert se1 != se3

0 comments on commit 19a46af

Please sign in to comment.