From b87086b3c9a7746cc4f730e95887d357649a63dc Mon Sep 17 00:00:00 2001 From: Carsten Schnober Date: Tue, 3 Dec 2024 13:32:49 +0100 Subject: [PATCH] Add tests. --- tempo_embeddings/text/corpus.py | 4 ++-- tests/unit/text/test_corpus.py | 4 ++++ tests/unit/text/test_passage.py | 9 +++++++++ 3 files changed, 15 insertions(+), 2 deletions(-) diff --git a/tempo_embeddings/text/corpus.py b/tempo_embeddings/text/corpus.py index 6ca89ed..2e1f0c2 100644 --- a/tempo_embeddings/text/corpus.py +++ b/tempo_embeddings/text/corpus.py @@ -189,8 +189,8 @@ def centroid(self, use_2d_embeddings: bool = True) -> np.ndarray: """The mean for all passage embeddings.""" embeddings = self._select_embeddings(use_2d_embeddings) - if embeddings is None: - raise RuntimeError("No embeddings available.") + assert embeddings is not None, "No embeddings available." + return embeddings.mean(axis=0) def coordinates(self) -> pd.DataFrame: diff --git a/tests/unit/text/test_corpus.py b/tests/unit/text/test_corpus.py index d09592a..346f7e3 100644 --- a/tests/unit/text/test_corpus.py +++ b/tests/unit/text/test_corpus.py @@ -548,6 +548,10 @@ def test_fit_umap(self, corpus): with pytest.raises(RuntimeError): corpus._fit_umap() + with pytest.raises(RuntimeError): + corpus.passages[0]._embedding = None + corpus._fit_umap() + def test_sum(self, corpus): with pytest.raises(ValueError): Corpus.sum(corpus, corpus) diff --git a/tests/unit/text/test_passage.py b/tests/unit/text/test_passage.py index 9816626..cabe989 100644 --- a/tests/unit/text/test_passage.py +++ b/tests/unit/text/test_passage.py @@ -161,6 +161,15 @@ def test_to_dict(self, passage, expected): def test_init(self, text, metadata, expected): assert Passage(text, metadata) == expected + @pytest.mark.parametrize( + "embedding_compressed, exception", + [(None, None), ([1.0, 2.0], None), ([1, 2], pytest.raises(ValueError))], + ) + def test_init_embeddings_compressed(self, embedding_compressed, exception): + with exception or does_not_raise(): + passage = Passage("test", embedding_compressed=embedding_compressed) + assert passage.embedding_compressed == embedding_compressed + @pytest.mark.parametrize( "passage,expected", [