From cdf2810ec0827659e632b536280d0c7794f792c1 Mon Sep 17 00:00:00 2001 From: Carsten Schnober Date: Tue, 3 Dec 2024 11:20:13 +0100 Subject: [PATCH] Add SHOW ALL category, remove unused Output widgets. --- tempo_embeddings/visualization/jscatter.py | 33 +++++++++++----------- tests/unit/visualization/test_jscatter.py | 27 ++++++++---------- 2 files changed, 28 insertions(+), 32 deletions(-) diff --git a/tempo_embeddings/visualization/jscatter.py b/tempo_embeddings/visualization/jscatter.py index d84d9f9..50ed8ca 100644 --- a/tempo_embeddings/visualization/jscatter.py +++ b/tempo_embeddings/visualization/jscatter.py @@ -163,20 +163,18 @@ def with_corpora(self, corpora: list[Corpus], **kwargs) -> "JScatterVisualizer": def visualize(self) -> None: """Display the initial visualization.""" continuous_filters: list[widgets.Widget] = [ - widget + self._plot_widgets._continuous_field_filter(field) for field in self._continuous_fields - for widget in self._plot_widgets._continuous_field_filter(field) or [] ] category_filters: list[widgets.Widget] = [ - widget + self._plot_widgets._category_field_filter(field) for field in self._categorical_fields if field not in self.__EXCLUDE_FILTER_FIELDS - for widget in self._plot_widgets._category_field_filter(field) or [] ] _widgets: list[widgets.Widget] = [self._scatter.show()] + [ widgets.HBox(continuous_filters), - widgets.HBox(category_filters), + widgets.HBox([widget for widget in category_filters if widget is not None]), self._plot_widgets._cluster_button(), self._plot_widgets._export_button(), # self._top_words_button(), @@ -186,6 +184,8 @@ def visualize(self) -> None: class PlotWidgets: """A class for generating the widgets for a plot.""" + __SHOW_ALL: str = "" + def __init__(self, visualizer: "JScatterVisualizer"): self._visualizer = visualizer self._df = self._visualizer._df @@ -304,26 +304,27 @@ def _category_field_filter( if field not in self._df.columns: raise ValueError(f"'{field}' does not exist.") options = self._df[field].dropna().unique().tolist() + if field in self._df.columns and 1 < len(options) <= 50: selector = widgets.SelectMultiple( - options=options, - value=options, # TODO: filter out outliers + options=[self.__SHOW_ALL] + options, + value=[self.__SHOW_ALL], # TODO: filter out outliers description=field, layout={"width": "max-content"}, - rows=min(len(options), 10), + rows=min(len(options) + 1, 10), ) - selector_output = widgets.Output() - def handle_change(change): - filtered = self._df.loc[ - self._df[field].isin(change.new) | self._df[field].isna() - ] + if self.__SHOW_ALL in change.new: + filtered = self._df + else: + filtered = self._df.loc[self._df[field].isin(change.new)] + self._filter(field, filtered.index) selector.observe(handle_change, names="value") - return selector, selector_output + return selector else: logging.info( f"Skipping field {field} with {len(options)} option(s) for filtering." @@ -353,8 +354,6 @@ def _continuous_field_filter( continuous_update=True, ) - selection_output = widgets.Output() - def handle_slider_change(change): filtered = self._df.loc[ ( @@ -367,7 +366,7 @@ def handle_slider_change(change): selection.observe(handle_slider_change, names="value") - return selection, selection_output + return selection def _filter(self, field, index): """Filter the scatter plot based on the given field and index. diff --git a/tests/unit/visualization/test_jscatter.py b/tests/unit/visualization/test_jscatter.py index 978578a..bdaf5a7 100644 --- a/tests/unit/visualization/test_jscatter.py +++ b/tests/unit/visualization/test_jscatter.py @@ -2,13 +2,7 @@ from unittest import mock import pytest -from ipywidgets.widgets import ( - Button, - HBox, - Output, - SelectionRangeSlider, - SelectMultiple, -) +from ipywidgets.widgets import Button, HBox, SelectionRangeSlider, SelectMultiple from tempo_embeddings.text.corpus import Corpus from tempo_embeddings.visualization.jscatter import JScatterVisualizer @@ -43,8 +37,6 @@ def test_visualize( exception, ): widget_types = [HBox, HBox, HBox, Button, DownloadButton] - cat_types = [SelectionRangeSlider, Output] - cont_types = [SelectMultiple, Output] visualizer = JScatterVisualizer( [corpus], @@ -54,14 +46,19 @@ def test_visualize( with exception or does_not_raise(): visualizer.visualize() - widgets = mock_display.call_args.args + top_widgets = mock_display.call_args.args + """Top-level widgets passed to the display() call.""" - categorical_filters = widgets[1].children - continous_filters = widgets[2].children + categorical_filters = top_widgets[1].children + continous_filters = top_widgets[2].children - assert [type(w) for w in widgets] == widget_types - assert [type(w) for w in categorical_filters] == cat_types * expected_cat - assert [type(w) for w in continous_filters] == cont_types * expected_cont + assert [type(w) for w in top_widgets] == widget_types + assert [type(w) for w in categorical_filters] == [ + SelectionRangeSlider + ] * expected_cat + assert [type(w) for w in continous_filters] == [ + SelectMultiple + ] * expected_cont @pytest.mark.parametrize( "tooltip_fields,expected",