Skip to content

Commit

Permalink
Add JScatterContainer for tabbed display.
Browse files Browse the repository at this point in the history
  • Loading branch information
carschno committed Dec 5, 2024
1 parent b430aba commit d382a83
Showing 1 changed file with 155 additions and 121 deletions.
276 changes: 155 additions & 121 deletions tempo_embeddings/visualization/jscatter.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import csv
import logging
from typing import Any, Optional
from typing import Any, Iterable, Optional

import jscatter
import pandas as pd
from IPython.display import clear_output, display
from IPython.display import display
from ipywidgets import widgets

from ..settings import STOPWORDS
Expand All @@ -13,22 +13,49 @@
from .util import DownloadButton


class JScatterContainer:
    """A container with tabs for JScatterVisualizer objects.

    Every tab holds the widgets of one visualizer, stacked vertically.
    """

    def __init__(self, corpora: list[Corpus], **kwargs):
        """Create a JScatterContainer object to visualize a list of corpora.

        Args:
            corpora (list[Corpus]): The corpora to visualize initially.
        KwArgs:
            Arguments to pass to the visualizer, overriding the current values.
        """
        self._tab = widgets.Tab()

        # The root visualizer. It receives this container so that it can
        # add further tabs to it later on.
        self._visualizer = JScatterVisualizer(corpora, container=self, **kwargs)

        self.add_tab(self._visualizer, **kwargs)

    def add_tab(self, visualizer: "JScatterVisualizer", **kwargs):
        """Append a new tab that shows the widgets of the given visualizer.

        Args:
            visualizer (JScatterVisualizer): The visualizer whose widgets
                (as returned by its get_widgets() method) fill the new tab.
        KwArgs:
            Accepted so that callers can forward their keyword arguments;
            they are not used by this method.
        """
        # Re-assign the children instead of mutating them in place, so the
        # tab widget receives a fresh sequence including the new tab.
        self._tab.children = [
            *self._tab.children,
            widgets.VBox(visualizer.get_widgets()),
        ]

    def visualize(self):
        """Display the tab widget with all tabs added so far."""
        display(self._tab)


class JScatterVisualizer:
"""A class for creating interactive scatter plots with Jupyter widgets."""

__REQUIRED_FIELDS: dict[str, Any] = {"x": pd.Float64Dtype(), "y": pd.Float64Dtype()}
"""Required fields and dtype."""
_DEFAULT_CONTINUOUS_FIELDS: set[str] = {"year"}
_EXCLUDE_FILTER_FIELDS: set[str] = {"month", "day", "year"}
_EXCLUDE_TOOLTIP_FIELDS: set[str] = {"date"}

__DEFAULT_CONTINUOUS_FIELDS: set[str] = {"year"}
__EXCLUDE_FILTER_FIELDS: set[str] = {"month", "day", "year"}
__EXCLUDE_TOOLTIP_FIELDS: set[str] = {"date"}
_REQUIRED_FIELDS: dict[str, Any] = {"x": pd.Float64Dtype(), "y": pd.Float64Dtype()}
"""Required fields and dtype."""

def __init__(
self,
corpora: list[Corpus],
*,
container: JScatterContainer,
categorical_fields: Optional[list[str]] = None,
continuous_filter_fields: list[str] = __DEFAULT_CONTINUOUS_FIELDS,
continuous_filter_fields: list[str] = _DEFAULT_CONTINUOUS_FIELDS,
tooltip_fields: list[str] = None,
color_by: list[str] = ["cluster", "label"],
keyword_extractor: Optional[KeywordExtractor] = None,
Expand All @@ -47,6 +74,8 @@ def __init__(
self._validate_corpora(corpora)
self._corpora = corpora

self._container = container

self._umap = corpora[0].umap
"""Common UMAP model; assuming all corpora have the same model."""

Expand Down Expand Up @@ -75,19 +104,20 @@ def __init__(
f"None of the color_by fields found in corpus: {color_by}"
) from e

self._plot_widgets = self.PlotWidgets(self)
self._scatter = self._plot_widgets._scatter()
self._plot_widgets = self.PlotWidgets(
df=self._df, color_by=self._color_by, tooltip_fields=self._tooltip_fields
)

def _validate_corpora(self, corpora):
for column in self.__REQUIRED_FIELDS:
for column in self._REQUIRED_FIELDS:
if not all(column in c.to_dataframe().columns for c in corpora):
raise ValueError(f"Missing required field '{column}' in corpora.")

def _valid_tooltip_fields(self, tooltip_fields: set[str]) -> set[str]:
return (
set(tooltip_fields)
.intersection(self._df.columns)
.difference(self.__EXCLUDE_TOOLTIP_FIELDS)
.difference(self._EXCLUDE_TOOLTIP_FIELDS)
.difference(
(
column
Expand All @@ -110,7 +140,7 @@ def _init_df(self):
self._df["date"] = self._df["date"].apply(pd.to_datetime)

# Validate required fields
for field, dtype in self.__REQUIRED_FIELDS.items():
for field, dtype in self._REQUIRED_FIELDS.items():
if field not in self._df.columns:
raise ValueError(f"Required field '{field}' not found.")
if self._df[field].dtype != dtype:
Expand All @@ -120,25 +150,12 @@ def _init_df(self):
if self._df[field].isna().any():
raise ValueError(f"Field '{field}' contains NaN values.")

def _selected(self) -> list[int]:
"""Return the indices of currently selected/filtered/all rows.
Returns:
The indices of the selected/filtered/all rows.
"""
if self._scatter.selection().size > 0:
index = self._scatter.selection()
else:
try:
# this should be identical with the intersection of all _indices values
filter_indices = self._scatter.filter()
except AttributeError:
# filter() raises error if it has not been set yet
logging.debug("No filter set.")
index = self._df.index
else:
index = filter_indices if filter_indices.size > 0 else self._df.index
return index
def get_widgets(self) -> list[widgets.Widget]:
    """Collect the widgets of this visualizer.

    Returns:
        The widgets generated by the plot widgets object for the continuous
        and categorical fields, followed by the cluster button.
    """
    plot_widgets = self._plot_widgets.get_widgets(
        continuous_fields=self._continuous_fields,
        categorical_fields=self._categorical_fields,
    )
    # The cluster button is created here rather than by the plot widgets
    # object; it is appended as the last widget.
    cluster_button = self._cluster_button()
    return plot_widgets + [cluster_button]

def with_corpora(self, corpora: list[Corpus], **kwargs) -> "JScatterVisualizer":
"""Create a new JScatterVisualizer with the given corpora.
Expand All @@ -151,116 +168,81 @@ def with_corpora(self, corpora: list[Corpus], **kwargs) -> "JScatterVisualizer":
Returns:
JScatterVisualizer: A new JScatterVisualizer object.
"""
args = {
visualizer_args = {
"categorical_fields": self._categorical_fields,
"continuous_filter_fields": self._continuous_fields,
"tooltip_fields": self._tooltip_fields,
"color_by": [self._color_by],
"keyword_extractor": self._keyword_extractor,
} | kwargs
return JScatterVisualizer(corpora, **args)

def visualize(self) -> None:
"""Display the initial visualization."""
continuous_filters: list[widgets.Widget] = [
self._plot_widgets._continuous_field_filter(field)
for field in self._continuous_fields
]
category_filters: list[widgets.Widget] = [
self._plot_widgets._category_field_filter(field)
for field in self._categorical_fields
if field not in self.__EXCLUDE_FILTER_FIELDS
]

_widgets: list[widgets.Widget] = [self._scatter.show()] + [
widgets.HBox(continuous_filters),
widgets.HBox([widget for widget in category_filters if widget is not None]),
self._plot_widgets._cluster_button(),
self._plot_widgets._export_button(),
# self._top_words_button(),
]
display(*_widgets)
return JScatterVisualizer(corpora, container=self._container, **visualizer_args)

class PlotWidgets:
"""A class for generating the widgets for a plot."""

__SHOW_ALL: str = "<SHOW ALL>"
def _cluster_button(self) -> widgets.Button:
"""Create a button for clustering the data.
def __init__(self, visualizer: "JScatterVisualizer"):
self._visualizer = visualizer
self._df = self._visualizer._df
This button triggers the creation of a new set of corpora (the clusters) and adds a new visualizer to the JScatterContainer instance.
That is why this method is part of the JScatterVisualizer class, not the PlotWidgets class.
"""

self._indices = dict()
"""The indices of the filtered rows per field."""
def cluster(button): # pragma: no cover
# TODO: add clustering parameters

def _scatter(self) -> jscatter.Scatter:
"""Create the scatter plot."""

return (
jscatter.Scatter(data=self._visualizer._df, x="x", y="y")
.color(by=self._visualizer._color_by)
.axes(False)
.tooltip(True, properties=self._visualizer._tooltip_fields)
.legend(True)
clusters = list(
Corpus.from_dataframe(
self._df.iloc[self._plot_widgets.selected()], umap_model=self._umap
).cluster()
)

def _cluster_button(self) -> widgets.Button:
"""Create a button for clustering the data."""

# TODO: add selectors for clustering parameters
if self._keyword_extractor:
for c in clusters:
c.top_words = self._keyword_extractor.top_words(
c, use_2d_embeddings=True
)

def cluster(button): # pragma: no cover
# TODO: add clustering parameters
self._container.add_tab(self.with_corpora(clusters, tooltip_fields=None))

clusters = list(
Corpus.from_dataframe(
self._df.iloc[self._visualizer._selected()],
umap_model=self._visualizer._umap,
).cluster()
)
button = widgets.Button(
description="Cluster",
disabled=False,
button_style="", # 'success', 'info', 'warning', 'danger' or ''
tooltip="Cluster the data",
# icon="check", # (FontAwesome names without the `fa-` prefix)
)
button.on_click(cluster)

if self._visualizer._keyword_extractor:
for c in clusters:
c.top_words = self._visualizer._keyword_extractor.top_words(
c, use_2d_embeddings=True
)
return button

# TODO: visualize in new tab widget
self._visualizer.with_corpora(clusters, tooltip_fields=None).visualize()
class PlotWidgets:
"""A class for generating the widgets for a plot."""

# display(*widgets, clear=True)
__SHOW_ALL: str = "<SHOW ALL>"

button = widgets.Button(
description="Cluster",
disabled=False,
button_style="", # 'success', 'info', 'warning', 'danger' or ''
tooltip="Cluster the data",
# icon="check", # (FontAwesome names without the `fa-` prefix)
)
button.on_click(cluster)
def __init__(
self, *, df: pd.DataFrame, color_by: str, tooltip_fields: set[str]
):
self._df = df
self._color_by = color_by
self._tooltip_fields = tooltip_fields

return button
self._scatter_plot: jscatter.JScatter = self._scatter()

def _return_button(self) -> widgets.Button:
def _return(button):
clear_output(wait=True)
widgets = self._plot._widgets + [self._cluster_button()]
self._indices = dict()
"""The indices of the filtered rows per field."""

display(*widgets, clear=True)
def _scatter(self) -> jscatter.Scatter:
"""Create the scatter plot."""

button = widgets.Button(
description="Return",
disabled=False,
button_style="", # 'success', 'info', 'warning', 'danger' or ''
tooltip="Return to initial view",
return (
jscatter.Scatter(data=self._df, x="x", y="y")
.color(by=self._color_by)
.axes(False)
.tooltip(True, properties=self._tooltip_fields)
.legend(True)
)
button.on_click(_return)

return button

def _export_button(self) -> DownloadButton:
def export_button(self) -> DownloadButton:
def selected_rows():
return self._df.iloc[self._visualizer._selected()].to_csv(
return self._df.iloc[self.selected()].to_csv(
index=False, quoting=csv.QUOTE_ALL
)

Expand All @@ -274,8 +256,8 @@ def _top_words_button(self) -> widgets.Button:
def _show_top_words(b):
# TODO: create a link between self._df and the corpora
# TODO: keep/unify text column names
# Corpus.from_csv_stream(self._df.iloc[self._scatter.selection()].to_csv())
corpus = Corpus.from_dataframe(self._df[self._selected()])
# Corpus.from_csv_stream(self._df.iloc[self._scatter_plot.selection()].to_csv())
corpus = Corpus.from_dataframe(self._df[self.selected()])
top_words = self._keyword_extractor.top_words(corpus)
print(top_words)

Expand All @@ -302,7 +284,7 @@ def _category_field_filter(
# FIXME: this does not work for filtering by "top words"

if field not in self._df.columns:
raise ValueError(f"'{field}' does not exist.")
raise ValueError(f"'{field}' does not exist in the data.")
options = self._df[field].dropna().unique().tolist()

if field in self._df.columns and 1 < len(options) <= 50:
Expand Down Expand Up @@ -382,4 +364,56 @@ def _filter(self, field, index):
for _index in self._indices.values():
index = index.intersection(_index)

self._visualizer._scatter.filter(index)
self._scatter_plot.filter(index)

def selected(self) -> list[int]:
    """Return the indices of the rows that are currently in scope.

    The order of precedence is: the selection in the scatter plot if it is
    not empty, otherwise the filtered rows if there are any, otherwise all
    rows of the data frame.

    Returns:
        The indices of the selected/filtered/all rows.
    """
    plot = self._scatter_plot

    # A non-empty selection takes precedence over any filter
    if plot.selection().size > 0:
        return plot.selection()

    try:
        # this should be identical with the intersection of all _indices values
        filtered = plot.filter()
    except AttributeError:
        # filter() raises this error if no filter has been set yet
        return self._df.index

    if filtered.size > 0:
        return filtered

    # An empty filter result falls back to all rows
    return self._df.index

def get_widgets(
    self, *, continuous_fields: Iterable[str], categorical_fields: Iterable[str]
) -> list[widgets.Widget]:
    """Build the widgets for the plot and its filters.

    Args:
        continuous_fields (Iterable[str]): The continuous fields to filter on.
        categorical_fields (Iterable[str]): The categorical fields to filter on.

    Returns:
        list[widgets.Widget]: The scatter plot, one row with the continuous
            filters, one row with the category filters, and the export button.
    """
    continuous_row: list[widgets.Widget] = []
    for name in continuous_fields:
        continuous_row.append(self._continuous_field_filter(name))

    category_row: list[widgets.Widget] = []
    for name in categorical_fields:
        if name in JScatterVisualizer._EXCLUDE_FILTER_FIELDS:
            continue
        selector = self._category_field_filter(name)
        # Fields for which no filter widget was created are left out
        if selector is not None:
            category_row.append(selector)

    return [
        self._scatter_plot.show(),
        widgets.HBox(continuous_row),
        widgets.HBox(category_row),
        self.export_button(),
        # self._top_words_button(),
    ]

0 comments on commit d382a83

Please sign in to comment.