Skip to content

Commit

Permalink
Fix label combination in Corpus.__add__().
Browse files: browse the repository at this point in the history
  • Loading branch information
carschno committed Dec 3, 2024
1 parent c6fb49f commit b430aba
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 5 deletions.
12 changes: 7 additions & 5 deletions tempo_embeddings/text/corpus.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,16 +65,18 @@ def __add__(self, other: "Corpus") -> "Corpus":
logging.info("No UMAP model has been computed.")
umap = None

if self.label == other.label:
if self._label == other._label:
label = self.label
elif not self.label and not other.label:
elif self._label and other._label:
label = " + ".join((self.label, other.label))
elif not self._label and not other._label:
label = None
elif self.label:
elif self._label:
label = self.label
elif other.label:
elif other._label:
label = other.label
else:
label = " + ".join((self.label, other.label))
raise AssertionError("Uncovered label combination; this is a bug.")

return Corpus(
self._passages + other._passages, label=label or None, umap_model=umap
Expand Down
16 changes: 16 additions & 0 deletions tests/unit/text/test_corpus.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,22 @@ def test_add(self, test_passages):
expected = Corpus(test_passages[:2], None, umap_model=None)
assert Corpus([test_passages[0]]) + Corpus([test_passages[1]]) == expected

@pytest.mark.parametrize(
    "label1, label2, expected",
    [
        # Neither operand labelled: the expected value is the string "None",
        # not the None object.
        (None, None, "None"),
        # Exactly one operand labelled: that label is carried over unchanged.
        ("label1", None, "label1"),
        (None, "label2", "label2"),
        # Two different labels are joined with " + ".
        ("label1", "label2", "label1 + label2"),
        # Two identical labels are not duplicated.
        ("label", "label", "label"),
    ],
)
def test_add_label(self, label1, label2, expected):
    """Check the label of the sum of two corpora built from labels only.

    ``label1`` and ``label2`` are the labels given to the left and right
    operands; ``expected`` is the label the combined corpus must report.
    """
    left = Corpus(label=label1)
    right = Corpus(label=label2)
    combined = left + right

    assert combined.label == expected

def test_add_umap_fitted(self, corpus):
corpus2 = Corpus([Passage("test {i}") for i in range(5)])

Expand Down

0 comments on commit b430aba

Please sign in to comment.