diff --git a/tempo_embeddings/text/corpus.py b/tempo_embeddings/text/corpus.py index b960e41..0bf1ba8 100644 --- a/tempo_embeddings/text/corpus.py +++ b/tempo_embeddings/text/corpus.py @@ -65,16 +65,18 @@ def __add__(self, other: "Corpus") -> "Corpus": logging.info("No UMAP model has been computed.") umap = None - if self.label == other.label: + if self._label == other._label: label = self.label - elif not self.label and not other.label: + elif self._label and other._label: + label = " + ".join((self.label, other.label)) + elif not self._label and not other._label: label = None - elif self.label: + elif self._label: label = self.label - elif other.label: + elif other._label: label = other.label else: - label = " + ".join((self.label, other.label)) + raise AssertionError("Uncovered label combination; this is a bug.") return Corpus( self._passages + other._passages, label=label or None, umap_model=umap diff --git a/tests/unit/text/test_corpus.py b/tests/unit/text/test_corpus.py index 4ac037c..20cdb09 100644 --- a/tests/unit/text/test_corpus.py +++ b/tests/unit/text/test_corpus.py @@ -35,6 +35,22 @@ def test_add(self, test_passages): expected = Corpus(test_passages[:2], None, umap_model=None) assert Corpus([test_passages[0]]) + Corpus([test_passages[1]]) == expected + @pytest.mark.parametrize( + "label1, label2, expected", + [ + (None, None, "None"), + ("label1", None, "label1"), + (None, "label2", "label2"), + ("label1", "label2", "label1 + label2"), + ("label", "label", "label"), + ], + ) + def test_add_label(self, label1, label2, expected): + corpus1 = Corpus(label=label1) + corpus2 = Corpus(label=label2) + + assert (corpus1 + corpus2).label == expected + def test_add_umap_fitted(self, corpus): corpus2 = Corpus([Passage("test {i}") for i in range(5)])