Skip to content

Commit

Permalink
Merge pull request #121 from Semantics-of-Sustainability/feature/jsca…
Browse files Browse the repository at this point in the history
…tter

Add JScatter visualization
  • Loading branch information
carschno authored Oct 18, 2024
2 parents 07af48e + 6689103 commit 06a8766
Show file tree
Hide file tree
Showing 9 changed files with 605 additions and 394 deletions.
606 changes: 223 additions & 383 deletions notebooks/term_frequency.ipynb

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,24 @@ include_package_data = True
packages = find:
install_requires =
accelerate~=0.22.0
jupyter-scatter~=0.19.0
kneed~=0.8.5
seaborn~=0.13.0
transformers~=4.39.0
torch>=2.2.2
umap-learn~=0.5.4
wizmap~=0.1.2
matplotlib~=3.7.2 # Explicit version set for Windows build
python-dateutil~=2.9.0.post0
sacremoses~=0.0.53 # Required for XLM models
scikit-learn~=1.3.0
sentence_splitter~=1.4.0
stanza~=1.7.0
chromadb~=0.4.22
weaviate-client~=4.6.5
wtpsplit~=2.0.5
# Wtpsplit does not work with huggingface-hub 0.26 (https://github.com/segment-any-text/wtpsplit/issues/135)
huggingface-hub~=0.25.0
pydantic~=2.8.2
# Required for UMAP plotting:
pandas
Expand Down
9 changes: 8 additions & 1 deletion tempo_embeddings/text/corpus.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,10 +545,17 @@ def to_dataframe(self) -> pd.DataFrame:
"""
# TODO: add option for including compressed or full embedding or no centroid distances
# TODO: merge with hover_datas()

corpus_properties = {"corpus": self.label}
if self.top_words:
corpus_properties["top words"] = self.top_words_string()

return pd.DataFrame(
(
passage.to_dict() | {"distance_to_centroid": distance}
passage.to_dict()
| corpus_properties
| {"distance_to_centroid": distance}
for passage, distance in zip(
self.passages,
self.distances(normalize=True, use_2d_embeddings=True),
Expand Down
12 changes: 12 additions & 0 deletions tempo_embeddings/text/passage.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,17 @@
from typing import Any, Iterable, Optional

import pandas as pd
from dateutil.parser import parse

from .highlighting import Highlighting


class Passage:
"""A text passage with optional metadata and highlighting."""

_TYPE_CONVERTERS = {"year": int, "date": parse}
"""When exporting to dict, convert these fields using the specified converter."""

def __init__(
self,
text: str,
Expand Down Expand Up @@ -178,6 +182,8 @@ def hover_data(

def to_dict(self) -> pd.DataFrame:
"""Returns a dictionary representation of the passage."""

# TODO: merge with hover_data()
d = {
"text": self.text,
"ID_DB": self.get_unique_id(),
Expand Down Expand Up @@ -339,6 +345,12 @@ def from_weaviate_record(cls, _object) -> "Passage":
"""

metadata = _object.properties

# convert date types; can be removed once Weaviate index has the right types
for field in cls._TYPE_CONVERTERS:
if field in metadata:
metadata[field] = cls._TYPE_CONVERTERS[field](metadata[field])

text = metadata.pop("passage")
highlighting = (
Highlighting.from_string(metadata.pop("highlighting"))
Expand Down
295 changes: 295 additions & 0 deletions tempo_embeddings/visualization/jscatter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,295 @@
import logging
from typing import Optional

import jscatter
import pandas as pd
from IPython.display import clear_output, display
from ipywidgets import widgets

from ..settings import STOPWORDS
from ..text.corpus import Corpus
from ..text.keyword_extractor import KeywordExtractor


class JScatterVisualizer:
"""A class for creating interactive scatter plots with Jupyter widgets."""

def __init__(
self,
corpus,
categorical_fields: list[str] = ["newspaper", "label"],
continuous_filter_fields: list[str] = ["year"],
tooltip_fields: list[str] = ["year", "text", "label", "top words", "newspaper"],
fillna: dict[str, str] = {"newspaper": "NRC"},
color_by: str = "label",
keyword_extractor: Optional[KeywordExtractor] = None,
):
self._keyword_extractor = keyword_extractor or KeywordExtractor(
corpus, exclude_words=STOPWORDS
)
self._categorical_fields = categorical_fields
self._continuous_filter_fields = continuous_filter_fields
self._tooltip_fields = tooltip_fields
self._fillna = fillna
self._color_by = color_by

self._plot = PlotWidgets(
[corpus],
self._categorical_fields,
self._continuous_filter_fields,
self._tooltip_fields,
self._fillna,
self._color_by,
)
self._cluster_plot = None
"""Index of the current plot being visualized."""

@property
def clusters(self):
if self._cluster_plot is None:
logging.warning("No clusters have been computed yet.")
return None
else:
return self._cluster_plot._corpora

def _cluster_button(self) -> widgets.Button:
"""Create a button for clustering the data."""

# TODO: add selectors for clustering parameters

def cluster(button):
# TODO: add clustering parameters

if self._cluster_plot is None:
# Initialize clustered plot
clusters = list(self._plot._corpora[0].cluster())

if self._keyword_extractor:
for c in clusters:
c.top_words = self._keyword_extractor.top_words(c)
self._cluster_plot = PlotWidgets(
clusters,
self._categorical_fields,
self._continuous_filter_fields,
self._tooltip_fields,
self._fillna,
self._color_by,
)

widgets = self._cluster_plot._widgets + [self._return_button()]

display(*widgets, clear=True)

button = widgets.Button(
description="Cluster",
disabled=False,
button_style="", # 'success', 'info', 'warning', 'danger' or ''
tooltip="Cluster the data",
# icon="check", # (FontAwesome names without the `fa-` prefix)
)
button.on_click(cluster)

return button

def _return_button(self) -> widgets.Button:
def _return(button):
clear_output(wait=True)
widgets = self._plot._widgets + [self._cluster_button()]

display(*widgets, clear=True)

button = widgets.Button(
description="Return",
disabled=False,
button_style="", # 'success', 'info', 'warning', 'danger' or ''
tooltip="Return to initial view",
)
button.on_click(_return)

return button

def visualize(self) -> list[widgets.Widget]:
"""Display the initial visualization."""
widgets = self._plot._widgets + [self._cluster_button()]
display(*widgets)
return widgets


class PlotWidgets:
"""A class for holding the widgets for a plot."""

def __init__(
self,
corpora: list[Corpus],
categorical_fields: list[str],
continuous_filter_fields: list[str],
tooltip_fields: list[str],
fillna: dict[str, str],
color_by: str,
):
"""Create a PlotWidgets object to create the widgets for a JScatterVisualizer.
Args:
corpus (Corpus): The corpus to visualize.
categorical_fields (list[str], optional): The categorical fields to filter on.
continuous_filter_fields (list[str], optional): The continuous fields to filter on.
tooltip_fields (list[str], optional): The fields to show in the tooltip.
fillna (dict[str, str], optional): The values to fill NaN values with.
color_by (str, optional): The field to color the scatter plot by.
"""

self._indices: dict[str, pd.RangeIndex] = {}
"""Keeps track of filtered indices per filter field."""

self._corpora: list[Corpus] = corpora
self._fillna = fillna
self._tooltip_fields = tooltip_fields
self._color_by = color_by

self._categorical_fields = categorical_fields
self._continuous_fields = continuous_filter_fields

self._init_scatter()
self._init_widgets()

def __len__(self):
return len(self._corpora)

def _init_dataframe(self) -> pd.DataFrame:
"""Create a DataFrame from the corpora."""

self._df = (
pd.concat(
c.to_dataframe().assign(label=c.label).assign(outlier=c.is_outliers())
for c in self._corpora
)
.reset_index()
.fillna(self._fillna)
.convert_dtypes()
)
return self._df

def _init_scatter(self) -> jscatter.Scatter:
"""Create the scatter plot."""

self._scatter = (
jscatter.Scatter(data=self._init_dataframe(), x="x", y="y")
.color(by=self._color_by)
.axes(False)
.tooltip(True, properties=self._tooltip_fields)
)
return self._scatter

def _init_widgets(self) -> tuple[jscatter.Scatter, widgets.HBox, widgets.HBox]:
"""Create the widgets for filtering the scatter plot."""

category_filters: list[widgets.Widget] = [
widget
for field in self._categorical_fields
for widget in self._category_field_filter(field) or []
]
continuous_filters: list[widgets.Widget] = [
widget
for field in self._continuous_fields
for widget in self._continuous_field_filter(field) or []
]

self._widgets: tuple[jscatter.Scatter, widgets.HBox, widgets.HBox] = [
self._scatter.show(),
widgets.HBox(continuous_filters),
widgets.HBox(category_filters),
]

return self._widgets

def _category_field_filter(
self, field: str
) -> Optional[tuple[widgets.SelectMultiple, widgets.Output]]:
"""Create a selection widget for filtering on a categorical field.
Args:
field (str): The field to filter on.
Returns:
widgets.VBox: A widget containing the selection widget and the output widget
"""
# FIXME: this not work for filtering by "top words"

if field not in self._df.columns:
logging.warning(f"Categorical field '{field}' not found, ignoring")
return

options = self._df[field].unique().tolist()

if len(options) > 1:
selector = widgets.SelectMultiple(
options=options,
value=options, # TODO: filter out outliers
description=field,
layout={"width": "max-content"},
rows=min(len(options), 10),
)

selector_output = widgets.Output()

def handle_change(change):
self._filter(field, self._df.query(f"{field} in @change.new").index)

selector.observe(handle_change, names="value")

return selector, selector_output
else:
logging.debug(f"Skipping field {field} with only {len(options)} option(s)")

def _continuous_field_filter(
self, field: str
) -> Optional[tuple[widgets.SelectionRangeSlider, widgets.Output]]:
"""Create a selection widget for filtering on a continuous field.
Args:
field (str): The field to filter on.
Returns:
widgets.VBox: A widget containing a RangeSlider widget and the output widget
"""
if field not in self._df.columns:
logging.warning(f"Categorical field '{field}' not found, ignoring")
return

min_year = self._df[field].min()
max_year = self._df[field].max()

selection = widgets.SelectionRangeSlider(
options=[str(i) for i in range(min_year, max_year + 1)],
index=(0, max_year - min_year),
description=field,
continuous_update=True,
)

selection_output = widgets.Output()

def handle_slider_change(change):
start = int(change.new[0]) # noqa: F841
end = int(change.new[1]) # noqa: F841

self._filter(field, self._df.query("year > @start & year < @end").index)

selection.observe(handle_slider_change, names="value")

return selection, selection_output

def _filter(self, field, index):
"""Filter the scatter plot based on the given field and index.
Intersect the indices per field to get the final index to filter in the plot.
Args:
field (str): The field to filter on.
index (pd.RangeIndex): The index listing the rows to keep for this field
"""
self._indices[field] = index

index = self._df.index
for _index in self._indices.values():
index = index.intersection(_index)

self._scatter.filter(index)
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def test_passages():
Passage(
f"test text {str(i)}",
# FIXME: year should be int type
metadata={"provenance": "test_file", "year": str(1950 + i)},
metadata={"provenance": "test_file", "year": 1950 + i},
highlighting=Highlighting(1, 3),
# TODO: make this deterministic for testing
embedding=np.random.rand(768).tolist(),
Expand Down
Loading

0 comments on commit 06a8766

Please sign in to comment.