From 0876dedef5a6c73f0df9bda92118e9517a96704a Mon Sep 17 00:00:00 2001 From: Carsten Schnober Date: Wed, 18 Dec 2024 16:32:21 +0100 Subject: [PATCH 1/3] Add search filter. --- tempo_embeddings/visualization/jscatter.py | 45 ++++++++++++++++++++++ tests/unit/visualization/test_jscatter.py | 2 + 2 files changed, 47 insertions(+) diff --git a/tempo_embeddings/visualization/jscatter.py b/tempo_embeddings/visualization/jscatter.py index 0f64afa..8ad6e1a 100644 --- a/tempo_embeddings/visualization/jscatter.py +++ b/tempo_embeddings/visualization/jscatter.py @@ -534,6 +534,50 @@ def handle_slider_change(change): return selection + def _search_filter(self) -> widgets.HBox: + """Create a search widget for filtering on a field. + + Returns: + widgets.HBox: A widget containing a search box + """ + + search = widgets.Text( + description="Search:", + placeholder="Enter search term", + continuous_update=True, + ) + field_selector = widgets.Dropdown( + options=sorted( + [c for c in self._df.columns if self._df[c].dtype == "string"] + ), + value="text", + description="In field", + layout={"width": "max-content"}, + ) + _widgets = [search, field_selector] + + def handle_search_change(change): + for w in _widgets: + w.disabled = True + + search_term = search.value.strip() + if search_term: + filtered = self._df.loc[ + self._df[field_selector.value].str.contains(search_term) + ] + else: + filtered = self._df + + self._filter(field_selector, filtered.index) + + for w in _widgets: + w.disabled = False + + search.observe(handle_search_change, names="value") + field_selector.observe(handle_search_change, names="value") + + return widgets.HBox(_widgets) + def _filter(self, field, index): """Filter the scatter plot based on the given field and index. @@ -601,6 +645,7 @@ def get_widgets( widgets.HBox( [widget for widget in category_filters if widget is not None] ), + self._search_filter(), self._color_by_dropdown(), self._select_tooltips(), self._export_button(), diff --git a/tests/unit/visualization/test_jscatter.py b/tests/unit/visualization/test_jscatter.py index 1d014c5..98efc5f 100644 --- a/tests/unit/visualization/test_jscatter.py +++ b/tests/unit/visualization/test_jscatter.py @@ -40,6 +40,7 @@ def test_init(self, container): HBox, HBox, HBox, + HBox, Dropdown, SelectMultiple, HBox, @@ -139,6 +140,7 @@ def test_get_widgets( HBox, HBox, HBox, + HBox, Dropdown, SelectMultiple, HBox, From 303937acc14655aa07dd6a648d55f9465c2c1f06 Mon Sep 17 00:00:00 2001 From: Carsten Schnober Date: Wed, 18 Dec 2024 17:04:31 +0100 Subject: [PATCH 2/3] Add test_search_filter(). --- tests/unit/visualization/test_jscatter.py | 28 +++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/tests/unit/visualization/test_jscatter.py b/tests/unit/visualization/test_jscatter.py index 98efc5f..1dbe88d 100644 --- a/tests/unit/visualization/test_jscatter.py +++ b/tests/unit/visualization/test_jscatter.py @@ -2,6 +2,7 @@ from contextlib import nullcontext as does_not_raise from unittest import mock +import numpy as np import pandas as pd import pytest from ipywidgets.widgets import ( @@ -277,3 +278,30 @@ def test_select_tooltips(self, corpus): select_tooltips.value = ["provenance"] assert pw._scatter_plot.tooltip()["properties"] == ["provenance"] + + @pytest.mark.parametrize( + "search_term,field,expected", + [ + ("test", "text", np.arange(5)), + ("test", "provenance", np.arange(5)), + ("invalid", "text", []), + ], + ) + def test_search_filter(self, corpus, search_term, field, expected): + pw = JScatterVisualizer.PlotWidgets( + df=corpus.to_dataframe().convert_dtypes(), + color_by="corpus", + tooltip_fields=set(), + ) + + search_box = pw._search_filter() + assert isinstance(search_box, HBox) + assert [type(w) for w in search_box.children] == [Text, Dropdown] + + search_box.children[0].value = search_term + search_box.children[1].value = field + + np.testing.assert_equal(pw._scatter_plot.filter(), expected) + + search_box.children[0].value = "" + np.testing.assert_equal(pw._scatter_plot.filter(), np.arange(5)) From 70c2640bcc4688a39cc2ac5d6cda9735f2d66763 Mon Sep 17 00:00:00 2001 From: Carsten Schnober Date: Wed, 18 Dec 2024 17:07:32 +0100 Subject: [PATCH 3/3] Extract plot_widgets fixture. --- tests/unit/visualization/test_jscatter.py | 51 ++++++++++------------- 1 file changed, 21 insertions(+), 30 deletions(-) diff --git a/tests/unit/visualization/test_jscatter.py b/tests/unit/visualization/test_jscatter.py index 1dbe88d..b350226 100644 --- a/tests/unit/visualization/test_jscatter.py +++ b/tests/unit/visualization/test_jscatter.py @@ -223,7 +223,15 @@ def test_plot_button(self, corpus): class TestPlotWidgets: - def test_export_button(self, corpus, tmp_path): + @pytest.fixture + def plot_widgets(self, corpus): + return JScatterVisualizer.PlotWidgets( + df=corpus.to_dataframe().convert_dtypes(), + color_by="corpus", + tooltip_fields=set(), + ) + + def test_export_button(self, plot_widgets, tmp_path): expected_columns = [ "text", "ID_DB", @@ -238,10 +246,7 @@ def test_export_button(self, corpus, tmp_path): "distance_to_centroid", ] - pw = JScatterVisualizer.PlotWidgets( - df=corpus.to_dataframe(), color_by="corpus", tooltip_fields=set() - ) - export_button = pw._export_button() + export_button = plot_widgets._export_button() assert isinstance(export_button, HBox) assert [type(w) for w in export_button.children] == [Button, Text, Checkbox] @@ -254,30 +259,22 @@ def test_export_button(self, corpus, tmp_path): df = pd.read_csv(target_file) assert df.columns.to_list() == expected_columns - assert len(df) == len(corpus) - - def test_color_by(self, corpus): - pw = JScatterVisualizer.PlotWidgets( - df=corpus.to_dataframe(), color_by="corpus", tooltip_fields=set() - ) + assert len(df) == 5 - color_box = pw._color_by_dropdown() + def test_color_by(self, plot_widgets): + color_box = plot_widgets._color_by_dropdown() assert isinstance(color_box, Dropdown) assert color_box.value == "corpus" color_box.value = "year" - assert pw._scatter_plot.color()["by"] == "year" - - def test_select_tooltips(self, corpus): - pw = JScatterVisualizer.PlotWidgets( - df=corpus.to_dataframe(), color_by="corpus", tooltip_fields=set() - ) + assert plot_widgets._scatter_plot.color()["by"] == "year" - select_tooltips = pw._select_tooltips() + def test_select_tooltips(self, plot_widgets): + select_tooltips = plot_widgets._select_tooltips() assert isinstance(select_tooltips, SelectMultiple) select_tooltips.value = ["provenance"] - assert pw._scatter_plot.tooltip()["properties"] == ["provenance"] + assert plot_widgets._scatter_plot.tooltip()["properties"] == ["provenance"] @pytest.mark.parametrize( "search_term,field,expected", @@ -287,21 +284,15 @@ def test_select_tooltips(self, corpus): ("invalid", "text", []), ], ) - def test_search_filter(self, corpus, search_term, field, expected): - pw = JScatterVisualizer.PlotWidgets( - df=corpus.to_dataframe().convert_dtypes(), - color_by="corpus", - tooltip_fields=set(), - ) - - search_box = pw._search_filter() + def test_search_filter(self, plot_widgets, search_term, field, expected): + search_box = plot_widgets._search_filter() assert isinstance(search_box, HBox) assert [type(w) for w in search_box.children] == [Text, Dropdown] search_box.children[0].value = search_term search_box.children[1].value = field - np.testing.assert_equal(pw._scatter_plot.filter(), expected) + np.testing.assert_equal(plot_widgets._scatter_plot.filter(), expected) search_box.children[0].value = "" - np.testing.assert_equal(pw._scatter_plot.filter(), np.arange(5)) + np.testing.assert_equal(plot_widgets._scatter_plot.filter(), np.arange(5))