Skip to content

Commit

Permalink
Merge pull request #142 from Semantics-of-Sustainability/fix/cluster_…
Browse files Browse the repository at this point in the history
…label

Fix/cluster label
  • Loading branch information
carschno authored Dec 3, 2024
2 parents 1f0c5ef + b430aba commit 52aa826
Show file tree
Hide file tree
Showing 5 changed files with 248 additions and 95 deletions.
252 changes: 203 additions & 49 deletions notebooks/term_frequency.ipynb

Large diffs are not rendered by default.

35 changes: 14 additions & 21 deletions tempo_embeddings/text/corpus.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,14 +65,19 @@ def __add__(self, other: "Corpus") -> "Corpus":
logging.info("No UMAP model has been computed.")
umap = None

if self.top_words or other.top_words:
logging.warning(
"Dropping existing top words: %s, %s", self.top_words, other.top_words
)
if self._label == other._label:
label = self.label
elif self._label and other._label:
label = " + ".join((self.label, other.label))
elif not self._label and not other._label:
label = None
elif self._label:
label = self.label
elif other._label:
label = other.label
else:
raise AssertionError("Uncovered label combination; this is a bug.")

label = " + ".join(
(label for label in (self.label, other.label) if label != str(None))
)
return Corpus(
self._passages + other._passages, label=label or None, umap_model=umap
)
Expand Down Expand Up @@ -354,12 +359,8 @@ def cluster(
cluster_passages[cluster].append(passage)

for cluster, passages in cluster_passages.items():
label = "; ".join(
[
self.label,
OUTLIERS_LABEL if cluster == -1 else f"cluster {cluster}",
]
)
label = self.label + "; " if self.label else ""
label += OUTLIERS_LABEL if cluster == -1 else f"cluster {cluster}"
yield Corpus(tuple(passages), label=label, umap_model=self._umap)

def compress_embeddings(self) -> np.ndarray:
Expand Down Expand Up @@ -695,11 +696,3 @@ def from_csv_stream(
yield Corpus(batch)
else:
yield Corpus(tuple(passages))

@classmethod
def sum(cls, *corpora) -> "Corpus":
    """Merge several corpora into one, refusing duplicate labels.

    Args:
        *corpora: The corpora to merge; each one is read for its ``label``.
    Returns:
        Corpus: The result of adding all given corpora onto an empty ``Corpus``.
    Raises:
        ValueError: If any label occurs on more than one of the corpora.
    """
    label_counts = Counter(corpus.label for corpus in corpora)
    for occurrences in label_counts.values():
        if occurrences > 1:
            raise ValueError("Corpora with the same label cannot be merged.")

    # An empty Corpus is the start value, so the merge runs through ``+``.
    return sum(corpora, Corpus())
7 changes: 2 additions & 5 deletions tempo_embeddings/text/keyword_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from sklearn.exceptions import NotFittedError
from sklearn.feature_extraction.text import TfidfVectorizer

from ..settings import OUTLIERS_LABEL, STOPWORDS
from ..settings import STOPWORDS
from .corpus import Corpus


Expand Down Expand Up @@ -140,10 +140,7 @@ def _top_words(

exclude_words: set[str] = {word.casefold() for word in (exclude_words or [])}

if corpus.label in (-1, OUTLIERS_LABEL):
words = [OUTLIERS_LABEL]
else:
words = self._tf_idf_words(corpus, use_2d_embeddings=use_2d_embeddings)
words = self._tf_idf_words(corpus, use_2d_embeddings=use_2d_embeddings)

for word, score in words:
if (
Expand Down
17 changes: 8 additions & 9 deletions tempo_embeddings/visualization/jscatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def __init__(
self._umap = corpora[0].umap
"""Common UMAP model; assuming all corpora have the same model."""

merged_corpus = Corpus.sum(*corpora)
merged_corpus = sum(corpora, Corpus())
self._keyword_extractor = keyword_extractor or KeywordExtractor(
merged_corpus, exclude_words=STOPWORDS
)
Expand Down Expand Up @@ -290,14 +290,14 @@ def _show_top_words(b):

def _category_field_filter(
self, field: str
) -> Optional[tuple[widgets.SelectMultiple, widgets.Output]]:
) -> Optional[widgets.SelectMultiple]:
"""Create a selection widget for filtering on a categorical field.
Args:
field (str): The field to filter on.
Returns:
widgets.VBox: A widget containing the selection widget and the output widget
widgets.SelectMultiple: A widget containing the selection widget or None if the field is not suitable for filtering.
"""
# FIXME: this does not work for filtering by "top words"

Expand All @@ -324,22 +324,21 @@ def handle_change(change):

selector.observe(handle_change, names="value")

return selector
widget = selector
else:
logging.info(
f"Skipping field {field} with {len(options)} option(s) for filtering."
)
return
widget = None
return widget

def _continuous_field_filter(
self, field: str
) -> Optional[tuple[widgets.SelectionRangeSlider, widgets.Output]]:
def _continuous_field_filter(self, field: str) -> widgets.SelectionRangeSlider:
"""Create a selection widget for filtering on a continuous field.
Args:
field (str): The field to filter on.
Returns:
widgets.VBox: A widget containing a RangeSlider widget and the output widget
widgets.SelectionRangeSlider: A widget containing a RangeSlider
"""
if field not in self._df.columns:
raise ValueError(f"Field '{field}' not found.")
Expand Down
32 changes: 21 additions & 11 deletions tests/unit/text/test_corpus.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,22 @@ def test_add(self, test_passages):
expected = Corpus(test_passages[:2], None, umap_model=None)
assert Corpus([test_passages[0]]) + Corpus([test_passages[1]]) == expected

@pytest.mark.parametrize(
    "label1, label2, expected",
    [
        # Neither corpus is labelled.
        (None, None, "None"),
        # Exactly one corpus is labelled.
        ("label1", None, "label1"),
        (None, "label2", "label2"),
        # Both are labelled, with different labels.
        ("label1", "label2", "label1 + label2"),
        # Both carry the same label.
        ("label", "label", "label"),
    ],
)
def test_add_label(self, label1, label2, expected):
    """Check the label of the corpus produced by adding two labelled corpora."""
    merged = Corpus(label=label1) + Corpus(label=label2)

    assert merged.label == expected

def test_add_umap_fitted(self, corpus):
corpus2 = Corpus([Passage("test {i}") for i in range(5)])

Expand Down Expand Up @@ -149,10 +165,11 @@ def test_cluster(self, caplog, n_passages, max_clusters, min_cluster_size):
for cluster in clusters:
assert all(passage in corpus.passages for passage in cluster.passages)
assert len(cluster) >= min_cluster_size or cluster.is_outliers()
assert (
cluster.label.startswith("TestCorpus; cluster ")
or cluster.label == "TestCorpus; Outliers"
)

labels = sorted((cluster.label for cluster in clusters))
assert labels[0] == "TestCorpus; Outliers"
for i, label in enumerate(labels[1:]):
assert label == f"TestCorpus; cluster {i}"

def test_is_outliers(self, corpus):
assert not corpus.is_outliers()
Expand Down Expand Up @@ -551,10 +568,3 @@ def test_fit_umap(self, corpus):
with pytest.raises(RuntimeError):
corpus.passages[0]._embedding = None
corpus._fit_umap()

def test_sum(self, corpus):
    """Check that ``Corpus.sum`` rejects repeated labels and otherwise agrees with ``sum``."""
    # Passing the same corpus twice repeats its label.
    with pytest.raises(ValueError):
        Corpus.sum(corpus, corpus)

    other = Corpus([Passage("test")])
    merged = Corpus.sum(corpus, other)
    expected = sum([corpus, other], Corpus())
    assert merged == expected

0 comments on commit 52aa826

Please sign in to comment.