Skip to content

Commit

Permalink
Add SHOW ALL category, remove unused Output widgets.
Browse files Browse the repository at this point in the history
  • Loading branch information
carschno committed Dec 3, 2024
1 parent 8a51e5e commit cdf2810
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 32 deletions.
33 changes: 16 additions & 17 deletions tempo_embeddings/visualization/jscatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,20 +163,18 @@ def with_corpora(self, corpora: list[Corpus], **kwargs) -> "JScatterVisualizer":
def visualize(self) -> None:
"""Display the initial visualization."""
continuous_filters: list[widgets.Widget] = [
widget
self._plot_widgets._continuous_field_filter(field)
for field in self._continuous_fields
for widget in self._plot_widgets._continuous_field_filter(field) or []
]
category_filters: list[widgets.Widget] = [
widget
self._plot_widgets._category_field_filter(field)
for field in self._categorical_fields
if field not in self.__EXCLUDE_FILTER_FIELDS
for widget in self._plot_widgets._category_field_filter(field) or []
]

_widgets: list[widgets.Widget] = [self._scatter.show()] + [
widgets.HBox(continuous_filters),
widgets.HBox(category_filters),
widgets.HBox([widget for widget in category_filters if widget is not None]),
self._plot_widgets._cluster_button(),
self._plot_widgets._export_button(),
# self._top_words_button(),
Expand All @@ -186,6 +184,8 @@ def visualize(self) -> None:
class PlotWidgets:
"""A class for generating the widgets for a plot."""

__SHOW_ALL: str = "<SHOW ALL>"

def __init__(self, visualizer: "JScatterVisualizer"):
self._visualizer = visualizer
self._df = self._visualizer._df
Expand Down Expand Up @@ -304,26 +304,27 @@ def _category_field_filter(
if field not in self._df.columns:
raise ValueError(f"'{field}' does not exist.")
options = self._df[field].dropna().unique().tolist()

if field in self._df.columns and 1 < len(options) <= 50:
selector = widgets.SelectMultiple(
options=options,
value=options, # TODO: filter out outliers
options=[self.__SHOW_ALL] + options,
value=[self.__SHOW_ALL], # TODO: filter out outliers
description=field,
layout={"width": "max-content"},
rows=min(len(options), 10),
rows=min(len(options) + 1, 10),
)

selector_output = widgets.Output()

def handle_change(change):
filtered = self._df.loc[
self._df[field].isin(change.new) | self._df[field].isna()
]
if self.__SHOW_ALL in change.new:
filtered = self._df
else:
filtered = self._df.loc[self._df[field].isin(change.new)]

self._filter(field, filtered.index)

selector.observe(handle_change, names="value")

return selector, selector_output
return selector
else:
logging.info(
f"Skipping field {field} with {len(options)} option(s) for filtering."
Expand Down Expand Up @@ -353,8 +354,6 @@ def _continuous_field_filter(
continuous_update=True,
)

selection_output = widgets.Output()

def handle_slider_change(change):
filtered = self._df.loc[
(
Expand All @@ -367,7 +366,7 @@ def handle_slider_change(change):

selection.observe(handle_slider_change, names="value")

return selection, selection_output
return selection

def _filter(self, field, index):
"""Filter the scatter plot based on the given field and index.
Expand Down
27 changes: 12 additions & 15 deletions tests/unit/visualization/test_jscatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,13 +2,7 @@
from unittest import mock

import pytest
from ipywidgets.widgets import (
Button,
HBox,
Output,
SelectionRangeSlider,
SelectMultiple,
)
from ipywidgets.widgets import Button, HBox, SelectionRangeSlider, SelectMultiple

from tempo_embeddings.text.corpus import Corpus
from tempo_embeddings.visualization.jscatter import JScatterVisualizer
Expand Down Expand Up @@ -43,8 +37,6 @@ def test_visualize(
exception,
):
widget_types = [HBox, HBox, HBox, Button, DownloadButton]
cat_types = [SelectionRangeSlider, Output]
cont_types = [SelectMultiple, Output]

visualizer = JScatterVisualizer(
[corpus],
Expand All @@ -54,14 +46,19 @@ def test_visualize(
with exception or does_not_raise():
visualizer.visualize()

widgets = mock_display.call_args.args
top_widgets = mock_display.call_args.args
"""Top-level widgets passed to the display() call."""

categorical_filters = widgets[1].children
continous_filters = widgets[2].children
categorical_filters = top_widgets[1].children
continous_filters = top_widgets[2].children

assert [type(w) for w in widgets] == widget_types
assert [type(w) for w in categorical_filters] == cat_types * expected_cat
assert [type(w) for w in continous_filters] == cont_types * expected_cont
assert [type(w) for w in top_widgets] == widget_types
assert [type(w) for w in categorical_filters] == [
SelectionRangeSlider
] * expected_cat
assert [type(w) for w in continous_filters] == [
SelectMultiple
] * expected_cont

@pytest.mark.parametrize(
"tooltip_fields,expected",
Expand Down

0 comments on commit cdf2810

Please sign in to comment.